From 876bc47c67cdaa12ee2a82c17e3ca27aa6f9d00f Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 23 Apr 2026 11:16:55 +0200 Subject: [PATCH 1/9] [CK_TILE] Grouped Convolution Backward Data Direct Load (#6624) ## Proposed changes Add Grouped Convolution Backward Data with Direct Load into DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 device implementation. This enables direct global memory loading (bypassing LDS) for the backward data convolution path on gfx950, following the same pattern used in both backward weight and forward convolution. Direct load convolution backward data improves performance by avoiding LDS round-trips for certain configurations on gfx950, which supports a wider range of instructions. Currently correctness is checked only at usage point, but should be extended to a standalone UT in the future. --- .../blockwise_gemm_pipeline_xdlops_base.hpp | 9 +- ...lockwise_gemm_pipeline_xdlops_selector.hpp | 10 +- .../blockwise_gemm_pipeline_xdlops_v1.hpp | 15 +- ...nv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp | 1216 +++++++++++++++++ ...rouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 19 +- .../gridwise_gemm_xdl_cshuffle_conv_v3.hpp | 326 +++-- ..._grouped_conv_bwd_data_xdl_v3_instance.hpp | 85 ++ .../gpu/grouped_convolution_backward_data.hpp | 4 + .../grouped_convolution_backward_data_xdl.inc | 28 + .../grouped_conv2d_bwd_data/CMakeLists.txt | 2 + ...xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 49 + ..._xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 49 + 12 files changed, 1676 insertions(+), 136 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index 4b9b6e076e..abff9de535 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -33,8 +33,9 @@ template + bool TransposeC = false, + bool ALdsScalarLoadToVgpr = false, + bool BLdsScalarLoadToVgpr = false> struct BlockwiseGemmXdlops_pipeline_base { static constexpr auto I0 = Number<0>{}; @@ -389,7 +390,7 @@ struct BlockwiseGemmXdlops_pipeline_base Sequence<1, 1, 1, KPack>, Sequence<0, 1, 2, 3>, 3, - LdsScalarLoadToVgpr ? 1 : A_K1, + ALdsScalarLoadToVgpr ? 1 : A_K1, A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, Sequence<0, 1, 2, 3>, 3, - LdsScalarLoadToVgpr ? 1 : B_K1, + BLdsScalarLoadToVgpr ? 
1 : B_K1, B_K1>; AThreadCopy a_thread_copy_; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp index 461ca513f9..f1a093a7a8 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp @@ -32,12 +32,13 @@ template + bool DirectLoad = false, + bool ALdsScalarLoadToVgpr = false, + bool BLdsScalarLoadToVgpr = false> constexpr auto BlockGemmPipeline_Selector() { // Supported for Direct Load and V1 - if constexpr(LdsScalarLoadToVgpr) + if constexpr(ALdsScalarLoadToVgpr || BLdsScalarLoadToVgpr) { static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1); } @@ -65,7 +66,8 @@ constexpr auto BlockGemmPipeline_Selector() MRepeat, NRepeat, KPack, - LdsScalarLoadToVgpr>{}; + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>{}; } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index 723ef9cd1e..6c5b2a266b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -747,7 +747,8 @@ template + bool ALdsScalarLoadToVgpr = false, + bool BLdsScalarLoadToVgpr = false> struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1 { }; @@ -772,7 +773,8 @@ template + bool ALdsScalarLoadToVgpr, + bool BLdsScalarLoadToVgpr> struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1 + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr> : BlockwiseGemmXdlops_pipeline_base + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr> { using Base = BlockwiseGemmXdlops_pipeline_base; + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>; using Base::I0; using Base::KRepeat; using Base::xdlops_gemm; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp new file mode 100644 index 0000000000..270d4e264c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp @@ -0,0 +1,1216 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/host_utility/io.hpp" + +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const std::array gemm_kernel_args, + const index_t gemms_count, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) +{ +#if defined(__gfx9__) + // offset base pointer for each work-group + const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * num_k_per_block); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && + block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && + left <= right) + { + if(block_args_id < gemm_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + if constexpr(GridwiseGemm::DirectLoadEnabled) + { +#if defined(__gfx950__) + const auto a_grid_desc_ak0_m_ak1_transformed = + GridwiseGemm::template TransformGrid( + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_); + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1_transformed, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } + 
else + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1_transformed, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } +#endif + } + else + { + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } + else + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } + } +#else + ignore = karg; + ignore = gemm_kernel_args; + ignore = gemms_count; + ignore = compute_ptr_offset_of_batch; + ignore = num_k_per_block; + +#endif // End of if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +} +} // namespace + +// Conv backward data multiple D: +// input : output image A: [G, N, K, Ho, Wo] +// input : weight B: [G, K, C, Y, X], +// input : D0, D1, ... : [G, N, K, Ho, Wo] +// output : input image E: [G, N, C, Hi, Wi] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 + : public DeviceGroupedConvBwdDataMultipleD +{ + // TODO: Extend support for more spatial dimensions. + static_assert(NDimSpatial == 2 || NDimSpatial == 3, + "wrong! only implemented for 2D and 3D now"); + + static_assert(std::is_same_v, "A not NGHWC"); + static_assert(std::is_same_v, "B not GKYXC"); + static_assert(std::is_same_v, "C not NGHWK"); + + // MaxGroupedGemmGroupsNum is used to specify number of gemm args in compile time. With this + // implementation we can avoid copy data to workspace before kernel launch since number of + // groups is runtime parameter. If number of groups is larger than MaxGroupedGemmGroupsNum then + // we run this kernel in the loop. + static constexpr index_t MaxGroupedGemmGroupsNum = + ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0 + ? 1 + : 32; + + using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static_assert(NumDTensor == 0, "Not supported"); + // static_assert(DirectLoad, "Not supported"); + + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::MNKPadding; + static constexpr bool IsSplitKSupported = false; + + // TODO: Add support for different A and B data types. 
+ using ABDataType = ADataType; + + using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1; + + // Dummy function just used to create an alias to Grid Descriptors + static auto + GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform) + { + const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1(); + + const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1(); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + conv_to_gemm_transform.MakeCDescriptor_M_N(), 1, 1); + + return make_tuple(a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + e_grid_desc_mblock_mperblock_nblock_nperblock); + } + + static constexpr index_t ABlockTransferSrcScalarPerVectorAligned = + ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8 + ? 4 / sizeof(ADataType) + : ABlockTransferSrcScalarPerVector; + static constexpr index_t BBlockTransferSrcScalarPerVectorAligned = + BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8 + ? 4 / sizeof(BDataType) + : BBlockTransferSrcScalarPerVector; + + static constexpr bool ALdsScalarLoadToVgpr = false; + static constexpr bool BLdsScalarLoadToVgpr = true; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::RowMajor, + ADataType, + BDataType, + AccDataType, + EDataType, + EDataType, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXdl, + NPerXdl, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + DirectLoad ? ABlockTransferSrcScalarPerVectorAligned : ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + DirectLoad ? 
BBlockTransferSrcScalarPerVectorAligned : BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector, + BlkGemmPipeSched, + BlkGemmPipelineVer, + AComputeType, + BComputeType, + DirectLoad, + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>; + + template + static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) + { + const auto grid_desc_m_k = transform_tensor_descriptor( + desc_k0_m_k1, + make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k; + } + + // Note: the dummy function is used just to create the alias + constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform; + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)); + + using AGridDesc_AK0_M_AK1 = remove_cvref_t>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t>; + using EGridDesc_MPerBlock_NBlock_NPerBlock = remove_cvref_t>; + + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); + + struct GemmArgs + { + GemmArgs() = default; + GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + EGridDesc_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, + index_t BlockStart, + index_t BlockEnd, + bool HasMainKBlockLoop) + : a_grid_desc_ak0_m_ak1_(a_grid_desc_ak0_m_ak1), + b_grid_desc_bk0_n_bk1_(b_grid_desc_bk0_n_bk1), + e_grid_desc_mblock_mperblock_nblock_nperblock_( + e_grid_desc_mblock_mperblock_nblock_nperblock), + BlockStart_(BlockStart), + BlockEnd_(BlockEnd), + HasMainKBlockLoop_(HasMainKBlockLoop) + + { + } + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + EGridDesc_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + index_t BlockStart_, BlockEnd_; + bool HasMainKBlockLoop_; + }; + // block-to-e-tile map for elementwise kernels + using Block2TileMapInOutElementwise = BlockToCTileMap_M00_N0_M01Adapt; + using Block2TileMapWeiElementwise = BlockToCTileMap_M00_N0_M01Adapt; + static constexpr index_t ClusterLengthMPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr index_t ElementwiseBlocksize = ClusterLengthMPerBlock * ClusterLengthNPerBlock; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, // output image + const void* p_b, // weight + const std::array&, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, + const std::array& a_g_n_k_wos_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>&, + const std::array, NumDTensor>&, + const std::array& e_g_n_c_wis_lengths, + const std::array& e_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const 
std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + ck::index_t split_k = 1) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_k_wos_lengths[0]}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + bool image_covered_dilation = true; + bool image_covered_strides = true; + for(index_t d = 0; d < NDimSpatial; d++) + { + // If dilation and stride is not equal we will have some empty places + image_covered_dilation &= + conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1; + // If stride is larger than windows size then we will have some empty places + image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3]; + } + bool if_d_is_output_mem = false; + bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides; + + // Temporary workaround untill prove/fix above conditions. + bwd_needs_zero_out = !if_d_is_output_mem; + e_space_size_bytes = + ck::accumulate_n( + e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(EDataType); + + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + // problem definition + const index_t Z = b_g_k_c_xs_lengths[ZIdx]; + const index_t Y = b_g_k_c_xs_lengths[YIdx]; + const index_t X = b_g_k_c_xs_lengths[XIdx]; + + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; + + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + index_t grid_size = 0; + // Allocate place for sets of gemms + gemm_kernel_args_.resize( + math::integer_divide_ceil(ZTilde * YTilde * XTilde, MaxGroupedGemmGroupsNum)); + + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + NDimSpatial == 3 ? 
math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice * ZDotSlice <= 0) + { + continue; + } + + std::array tildes; + if constexpr(NDimSpatial == 2) + { + tildes = {i_ytilde, i_xtilde}; + } + else if constexpr(NDimSpatial == 3) + { + tildes = {i_ztilde, i_ytilde, i_xtilde}; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } + + ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes, + k_batch_}; + + conv_N_per_block_ = conv_to_gemm_transform_.N_; + + const auto a_grid_desc_ak0_m_ak1 = [&]() { + return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); + }(); + + const auto b_grid_desc_bk0_n_bk1 = [&]() { + return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + }(); + + // desc for problem definition + const auto a_grid_desc_m_k = + transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1); + const auto b_grid_desc_n_k = + transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); + + const auto GemmM = a_grid_desc_m_k.GetLength(I0); + const auto GemmN = b_grid_desc_n_k.GetLength(I0); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + conv_to_gemm_transform_.MakeCDescriptor_M_N(), + GridwiseGemm::CalculateMBlock(GemmM), + GridwiseGemm::CalculateNBlock(GemmN)); + + a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); + b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); + e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + e_grid_desc_mblock_mperblock_nblock_nperblock); + + const index_t grid_size_grp = + std::get<0>(GridwiseGemm::CalculateGridSize(GemmM, GemmN, 1, 1)); + const index_t BlockStart = grid_size; + const index_t BlockEnd = grid_size + grid_size_grp; + + grid_size += grid_size_grp; + + // const index_t GemmM = a_grid_desc_m_k.GetLength(I0); + // const index_t GemmN = b_grid_desc_n_k.GetLength(I0); + const index_t GemmK = a_grid_desc_m_k.GetLength(I1); + + // onst auto MBlock = GridwiseGemmCTranspose::CalculateMBlock(GemmM); + // onst auto NBlock = GridwiseGemmCTranspose::CalculateNBlock(GemmN); + + index_t k_grain = split_k * KPerBlock; + index_t K_split = (GemmK + k_grain - 1) / k_grain * KPerBlock; + + const bool HasMainKBlockLoop = + GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + gemm_kernel_args_[gemms_count_ / MaxGroupedGemmGroupsNum] + [gemms_count_ % MaxGroupedGemmGroupsNum] = + GemmArgs{a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + e_grid_desc_mblock_mperblock_nblock_nperblock, + BlockStart, + BlockEnd, + HasMainKBlockLoop}; + gemms_count_++; + if(gemms_count_ % MaxGroupedGemmGroupsNum == 0) + { + gemms_grid_size_.push_back(grid_size); + grid_size = 0; + } + } + } + } + gemm_kernel_args_.resize( + math::integer_divide_ceil(gemms_count_, MaxGroupedGemmGroupsNum)); + gemms_grid_size_.push_back(grid_size); + + // A/B/Ds/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0]; + + num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_; + } + + 
std::size_t GetWorkspaceSizeBytes() const { return 0; } + + void Print() const + { + for(std::size_t i = 0; i < a_grid_desc_m_k_container_.size(); i++) + { + std::cout << "a_grid_desc_m_ak_container_" << a_grid_desc_m_k_container_[i] + << std::endl; + + std::cout << "b_grid_desc_n_bk_container_" << b_grid_desc_n_k_container_[i] + << std::endl; + + std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i] + << std::endl; + } + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + // tensor descriptor for problem definition + index_t num_group_; + index_t conv_N_per_block_; + std::vector a_grid_desc_m_k_container_; + std::vector b_grid_desc_n_k_container_; + std::vector + e_grid_desc_mblock_mperblock_nblock_nperblock_container_; + + // tensor descriptor for block-wise copy + std::vector a_grid_desc_ak0_m_ak1_container_; + std::vector b_grid_desc_bk0_n_bk1_container_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOp a_element_op_; + BElementwiseOp b_element_op_; + CDEElementwiseOp cde_element_op_; + + std::array a_g_n_k_wos_lengths_; + std::array b_g_k_c_xs_lengths_; + std::array e_g_n_c_wis_lengths_; + std::array conv_filter_strides_; + std::array input_left_pads_; + std::array input_right_pads_; + + const index_t k_batch_; + index_t num_workgroups_per_Conv_N_; + std::vector gemms_grid_size_; + index_t gemms_count_ = 0; + std::vector> gemm_kernel_args_; + + bool bwd_needs_zero_out; + long_index_t e_space_size_bytes; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + template + float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + const index_t gdy = arg.num_group_; + const index_t gdz = arg.k_batch_; + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + EDataType* p_e_grid = arg.p_e_grid_; + + for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size(); + gemm_set_id++) + { + const index_t GemmM = arg.a_grid_desc_m_k_container_[gemm_set_id].GetLength(I0); + const index_t GemmN = arg.b_grid_desc_n_k_container_[gemm_set_id].GetLength(I0); + const index_t GemmK = arg.a_grid_desc_m_k_container_[gemm_set_id].GetLength(I1); + typename GridwiseGemm::Argument gemm_arg{ + p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const index_t gdx = arg.gemms_grid_size_[gemm_set_id]; + + const index_t gemms_count_for_set = + gemm_set_id == arg.gemm_kernel_args_.size() - 1 + ? 
arg.gemms_count_ - MaxGroupedGemmGroupsNum * gemm_set_id + : MaxGroupedGemmGroupsNum; + + const std::array& gemm_kernel_args = + arg.gemm_kernel_args_[gemm_set_id]; + + const auto clear_workspace = [&]() { + if(arg.bwd_needs_zero_out && gemm_set_id == 0) + { + hip_check_error(hipMemsetAsync( + p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_)); + } + }; + + bool has_loop_in_all_gemm = true; + bool no_loop_in_all_gemm = true; + for(auto i = 0; i < gemms_count_for_set; i++) + { + has_loop_in_all_gemm &= gemm_kernel_args[i].HasMainKBlockLoop_; + no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_; + } + + auto launch_kernel = [&](auto has_main_k_block_loop_, auto no_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop_.value; + constexpr bool no_main_loop = no_main_k_block_loop.value; + const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MPerBlock_NBlock_NPerBlock, + MaxGroupedGemmGroupsNum, + GemmArgs, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + ElementOp, + has_main_loop, + no_main_loop>; + + return launch_and_time_kernel_with_preprocess(stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + gemm_kernel_args, + gemms_count_for_set, + arg.compute_ptr_offset_of_batch_, + 1); + }; + if(has_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(no_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + return ave_time; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + if(arg.k_batch_ > 1) + { + if constexpr(IsSplitKSupported) + { + ave_time += + RunMultiDGemm(arg, stream_config); + } + } + else + { + ave_time += RunMultiDGemm(arg, stream_config); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + // check device + if constexpr(DirectLoad) + { + if(get_device_name() != "gfx950") + { + return false; + } + } + + if constexpr(!IsSplitKSupported) + { + if(arg.k_batch_ > 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "SplitK tests are not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + + if(ck::is_gfx11_supported() && arg.k_batch_ > 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "SplitK tests are not supported!" 
<< " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + + // Specialization + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ConvBwdDataSpecialization is unsupported!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + + return false; + } + } + } + + // vector load for A matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v) + { + if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // vector load for B matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v) + { + + if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported B Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // vector store for E + if constexpr(is_same_v || + is_same_v) + { + // vector store C matrix into global memory + if(!(ConvC % CShuffleBlockTransferScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported E Layout!" 
<< " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // Check gridwise gemm validity + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1); + const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * + arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + // Create gemm arguments with dummy values to check for validity + typename GridwiseGemm::Argument gemm_arg{nullptr, // p_as_grid + nullptr, // p_bs_grid + nullptr, // p_e_grid + GemmM, // M + GemmN, // N + GemmK, // K + I0, // StrideAs + I0, // StrideBs + I0, // StrideE + arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / AK1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) override + { + return std::make_unique(p_a, + p_b, + p_ds, + 
p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3" + << (DirectLoad ? "_DirectLoad" : "") + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", " + << MPerXdl << ", " + << NPerXdl << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle; + + str << ">"; + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index dade0515af..58de8dd3dc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -408,10 +408,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 ? 4 / sizeof(BDataType) : BBlockTransferSrcScalarPerVector; + static constexpr bool ALdsScalarLoadToVgpr = + (DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false); + static constexpr bool BLdsScalarLoadToVgpr = + (DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? 
true : false); + + // Note: Direct load use layout to create proper block and mmtile descriptor + // TODO: Fix and verify RC layout for not direct load (currently it returns wrong results) template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3< - tensor_layout::gemm::RowMajor, - tensor_layout::gemm::ColumnMajor, + std::conditional_t, + std::conditional_t, tensor_layout::gemm::RowMajor, ADataType, BDataType, @@ -456,7 +467,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, - DirectLoad>; + DirectLoad, + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>; using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index a0fca218d4..c134d34161 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -66,7 +66,9 @@ template + bool DirectLoad = false, + bool ALdsScalarLoadToVgpr = false, + bool BLdsScalarLoadToVgpr = false> struct GridwiseGemm_xdl_cshuffle_conv_v3 : public GridwiseGemm_xdl_cshuffle_base< ALayout, @@ -249,19 +251,90 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 return math::integer_divide_ceil(N, NPerBlock); } - template + template + __host__ __device__ static auto TransformGrid(const GridDesc_K0_MN_K1_T& desc) + { + + if constexpr(!DirectLoad) + { + return desc; + } + else + { + const index_t K = desc.GetLength(I0) * desc.GetLength(I2); + const index_t MN = desc.GetLength(I1); + + const auto desc_unmerged = transform_tensor_descriptor( + desc, + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, K0Number)), + make_pass_through_transform(MN), + make_pass_through_transform(K1Value)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto desc_permuted = transform_tensor_descriptor( + desc_unmerged, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(MN, K0Number)), + make_pass_through_transform(K1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, K0Number)), + make_pass_through_transform(MN), + make_pass_through_transform(K1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + } + + template __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) { - constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); - constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + if constexpr(DirectLoad && IsKContinous) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); - return transform_tensor_descriptor( - TileDesc_K0_MN_K1{}, - make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}))), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}), - make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{}); + + 
constexpr auto desc = transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_xor_with_modulo_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + desc, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + else + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } } template @@ -270,7 +343,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 { constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + return MakeGemmMmaTileDescriptor::value>( + ABlockDesc_AK0_M_AK1{}); } template @@ -279,7 +356,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 { constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + return MakeGemmMmaTileDescriptor::value>( + BBlockDesc_BK0_N_BK1{}); } struct Problem @@ -366,9 +447,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 { if constexpr(DirectLoad) { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(Number{}, I1, Number{})); + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{}, I1, Number{})); + } } else if constexpr(is_same_v) { @@ -389,9 +479,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 { if constexpr(DirectLoad) { - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(Number{}, I1, Number{})); + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{}, I1, Number{})); + } } else if constexpr(is_same_v) { @@ -410,34 +509,35 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 // Disable vector load from lds to vgpr for direct load (backward weight store with continous M // or N dimension) - static constexpr bool LdsScalarLoadToVgpr = DirectLoad; - using BlockwiseGemmPipe = remove_cvref_t< - decltype(BlockGemmPipeline_Selector< - BlkGemmPipelineVer, - BlkGemmPipeSched, - BlockSize, - ADataType, - BDataType, - ComputeTypeA, - AccDataType, - decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), - decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), - decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( + // static constexpr bool LdsScalarLoadToVgpr = DirectLoad; + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + 
BlockSize, + ADataType, + BDataType, + ComputeTypeA, + AccDataType, + decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())), + decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K( GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))), - decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K( GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))), - ABlockTransferSrcScalarPerVector, - BBlockTransferSrcScalarPerVector, - MPerBlock, - NPerBlock, - KPerBlock, - MPerXdl, - NPerXdl, - MXdlPerWave, - NXdlPerWave, - KPack, - DirectLoad, - LdsScalarLoadToVgpr>())>; + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + DirectLoad, + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>())>; template __device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch) @@ -517,8 +617,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const index_t k_id = 0, - const index_t k_batch = 1) + const index_t k_id = 0, + const index_t k_batch = 1, + const index_t block_idx_x = static_cast(blockIdx.x)) { const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1; const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1; @@ -535,8 +636,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( - make_multi_index(static_cast(blockIdx.x))); + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_idx_x)); if(!block_2_ctile_map.ValidCTileIndex( block_work_idx, @@ -570,23 +671,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 auto get_a_blockwise_copy = [&]() { if constexpr(DirectLoad) { - return ThreadGroupTensorSliceTransfer_DirectLoad< - ThisThreadBlock, - Sequence, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - 1, - ABlockTransferSrcScalarPerVector>( - a_grid_desc_ak0_m_ak1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0)); + return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType, + decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, + is_same::value ? 2 : 1, + ABlockTransferSrcScalarPerVector > + (a_grid_desc_ak0_m_ak1, + make_multi_index( + SplitKOffsetHack ? 
0 : k_id, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); } else { @@ -626,23 +723,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 auto get_b_blockwise_copy = [&]() { if constexpr(DirectLoad) { - return ThreadGroupTensorSliceTransfer_DirectLoad< - ThisThreadBlock, - Sequence, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - 1, - BBlockTransferSrcScalarPerVector>( - b_grid_desc_bk0_n_bk1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0)); + return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType, + decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, + is_same::value ? 2 : 1, + BBlockTransferSrcScalarPerVector > + (b_grid_desc_bk0_n_bk1, + make_multi_index( + SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); } else { @@ -750,8 +843,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const index_t k_id = 0, - const index_t k_batch = 1) + const index_t k_id = 0, + const index_t k_batch = 1, + const index_t block_idx_x = static_cast(blockIdx.x)) { const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1; const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1; @@ -771,7 +865,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( - make_multi_index(static_cast(blockIdx.x))); + make_multi_index(static_cast(block_idx_x))); if(!block_2_ctile_map.ValidCTileIndex( block_work_idx, @@ -805,23 +899,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 auto get_a_blockwise_copy = [&]() { if constexpr(DirectLoad) { - return ThreadGroupTensorSliceTransfer_DirectLoad< - ThisThreadBlock, - Sequence, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - 1, - ABlockTransferSrcScalarPerVector>( - a_grid_desc_ak0_m_ak1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0)); + return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType, + decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, + is_same::value ? 2 : 1, + ABlockTransferSrcScalarPerVector > + (a_grid_desc_ak0_m_ak1, + make_multi_index( + SplitKOffsetHack ? 
0 : k_id, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); } else { @@ -861,23 +951,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 auto get_b_blockwise_copy = [&]() { if constexpr(DirectLoad) { - return ThreadGroupTensorSliceTransfer_DirectLoad< - ThisThreadBlock, - Sequence, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - 1, - BBlockTransferSrcScalarPerVector>( - b_grid_desc_bk0_n_bk1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0)); + return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType, + decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, + is_same::value ? 2 : 1, + BBlockTransferSrcScalarPerVector > + (b_grid_desc_bk0_n_bk1, + make_multi_index( + SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); } else { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp new file mode 100644 index 0000000000..e2b0cb74ba --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp @@ -0,0 +1,85 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using BF8 = ck::bf8_t; +using F8 = ck::f8_t; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_data_xdl_v3_f16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, 
true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true> + + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_xdl_v3_bf16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true> + + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index f784b6ea51..09301474f0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp 
+++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -108,6 +108,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { + add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances( op_ptrs); @@ -148,6 +150,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { + add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc index 7c61f3ee66..8dae166dd1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -56,6 +56,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( #endif #ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 19e27cf173..7f2363affd 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -32,6 +32,8 @@ add_instance_library( xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..4d434cc390 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_bf16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_bf16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 0000000000..9d1fb4b93a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_f16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_f16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 90ca12f14ff0d3985053d2e2a0d83ae4c7b17321 Mon Sep 17 00:00:00 2001 From: KateJu <153474223+kateju12@users.noreply.github.com> Date: Thu, 23 Apr 2026 22:05:33 +0800 Subject: [PATCH 2/9] Add missing lds sync (#6655) ## Motivation ## Technical Details ## Test Plan ## Test Result ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--- .../device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp | 1 + .../device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp | 1 + 2 files changed, 2 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp index 9532f7e76a..87be350a44 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_wmma_fixed_nk.hpp @@ -162,6 +162,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) id_off += grid_size_grp; id_local += grid_size_grp; + block_sync_lds(); } } #else diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 9978b62b17..fa33e0fdea 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -136,6 +136,7 @@ __launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) id_off += grid_size_grp; id_local += grid_size_grp; + block_sync_lds(); } } #else From c86c0f89b41d5c41316084aaadd2f140310ac341 Mon Sep 17 00:00:00 2001 From: KateJu <153474223+kateju12@users.noreply.github.com> Date: Thu, 23 Apr 2026 22:08:50 +0800 Subject: [PATCH 3/9] Fix per-layer conv2d int8 CPU verification reference path (#6656) case example_conv2d_fwd_xdl_perlayer_quantization_int8.exe 1 0 ## Motivation ## Technical Details ## Test Plan ## Test Result ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--- ...n_conv2d_fwd_perlayer_quantization_example.inc | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc index 02228d7654..26c3165446 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc @@ -108,28 +108,35 @@ bool run_grouped_conv_fwd(bool do_verification, if(do_verification) { + Tensor c_host(out_g_n_k_wos_desc); + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + PassThrough>(); auto ref_invoker = ref_conv.MakeInvoker(); auto ref_argument = ref_conv.MakeArgument(in, wei, - out_host, + c_host, conv_param.conv_filter_strides_, conv_param.conv_filter_dilations_, conv_param.input_left_pads_, conv_param.input_right_pads_, in_element_op, wei_element_op, - out_element_op); + PassThrough{}); ref_invoker.Run(ref_argument); + out_host.ForEach([&](auto&, auto idx) + { + out_element_op(out_host(idx), c_host(idx)); + }); + out_device_buf.FromDevice(out_device.mData.data()); pass &= From 16a9ced35da5ce5a15ca3daf8aad3f0ba9f5e43c Mon Sep 17 00:00:00 2001 From: Artem Kuzmitckii Date: Thu, 23 Apr 2026 22:10:46 +0200 Subject: [PATCH 4/9] [CK] Fix divide-by-zero crash for grouped conv kernels (#6132) ## Motivation During run pytorch unit tests for conv3d: `test_dtypes_nn_functional_conv3d_cuda`, `test_fake_crossref_backward_amp_nn_functional_conv3d_cuda_float32` found divide-by-zero crash during CK kernel selection. Refs ROCM-20764 ## Technical Details Add assert for K0PerBlock equal 0, also covered other potential places related with k_batch calculation. ## Test Plan Run miopen command extracted from mentioned test: `MIOpenDriver convfp16 --spatial_dim 3 -I NCDHW -O NCDHW -f NCDHW -n 1 -c 1 -k 1 -g 1 --in_d 4 -H 4 -W 4 --fil_d 4 -y 4 -x 4 --pad_d 0 -p 0 -q 0 --conv_stride_d 2 -u 2 -v 2 --dilation_d 1 -l 1 -j 1 -m conv -F 4 -t 1` ## Test Result Passed ## Submission Checklist - [X] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
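For illustration only (not part of the patch): a minimal sketch of the failure mode and of the clamp described above. `clamp_gemm_k_batch` mirrors the helper added to `split_k_utils.hpp` further down in this patch, while `integer_divide_ceil` and the tile sizes here are simplified stand-ins, not the library's definitions.

```cpp
// Sketch: an unclamped k_batch of 0 zeroes the divisor of the conv-to-GEMM
// K-split computation; clamping it to >= 1 keeps the division well defined.
#include <cstdio>

using index_t = int;

// Simplified stand-in; the CK helper has the same non-zero-divisor contract.
constexpr index_t integer_divide_ceil(index_t x, index_t y) { return (x + y - 1) / y; }

// Mirrors the helper this patch adds to split_k_utils.hpp.
constexpr index_t clamp_gemm_k_batch(index_t k_batch) { return k_batch < 1 ? 1 : k_batch; }

int main()
{
    constexpr index_t GemmK1Number = 8;  // example value
    constexpr index_t K0PerBlock   = 4;  // example value
    constexpr index_t GemmKTotal   = 16; // degenerate small shape, as in the bug report

    index_t k_batch = 0; // what the selection heuristic could previously hand back
    k_batch         = clamp_gemm_k_batch(k_batch);

    // With the clamp the divisor is never zero; without it this divides by zero.
    const index_t k0 = integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * k_batch);
    std::printf("K0 per split = %d\n", k0);
}
```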
Signed-off-by: Artem Kuzmitckii --- .../device/impl/device_grouped_conv_bwd_weight_dl.hpp | 2 ++ .../impl/device_grouped_conv_bwd_weight_explicit.hpp | 1 + ...ed_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 1 + ...rouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 1 + ...ped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp | 1 + ...grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 1 + ...evice_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 1 + .../device_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 1 + ...device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 1 + .../gpu/device/impl/split_k_utils.hpp | 11 +++++++++++ .../transform_conv_bwd_weight_to_gemm.hpp | 4 ++++ .../transform_conv_bwd_weight_to_gemm_v2.hpp | 5 +++++ 12 files changed, 30 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 19a7536685..88c2207e09 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -19,6 +19,7 @@ #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #ifdef CK_EXPERIMENTAL_BUILDER #include "ck_tile/builder/reflect/description.hpp" @@ -856,6 +857,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp index a811d2f44a..172a53d652 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp @@ -179,6 +179,7 @@ struct DeviceGroupedConvBwdWeight_Explicit k_batch_ = split_k; } } + k_batch_ = clamp_gemm_k_batch(k_batch_); if constexpr(IsTwoStageNeeded) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index a3eab579e7..ed0378e23f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -670,6 +670,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); const auto descs = conv_to_gemm_transformer diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 1e23fef191..ff0616481f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -695,6 +695,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); const auto descs = conv_to_gemm_transformer diff --git 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 87117be4ce..bc44cf2bb3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -611,6 +611,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); const auto descs = conv_to_gemm_transformer_v2 diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 0ee5ac3647..011bb068f9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -717,6 +717,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); // Create initial descriptors with hack=false to check compactness const auto descs_initial = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index bfc88753a2..66fb526641 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -555,6 +555,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); std::array a_g_n_k_wos_strides_transposed = conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 46a9009f83..fef81b281a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -669,6 +669,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); // Create descriptors first (with hack flags temporarily set to false) // so we can check if element space sizes are divisible by k_batch diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 58de8dd3dc..07c8e02514 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -638,6 +638,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 { k_batch_ = split_k; } + k_batch_ = clamp_gemm_k_batch(k_batch_); // Create descriptors first (with hack flags temporarily set to false) // so we can check if element space sizes match product of dimensions diff --git 
a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp index 3a3bacd945..ea5b282ed1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/split_k_utils.hpp @@ -13,6 +13,13 @@ namespace ck { namespace tensor_operation { namespace device { +/// Ensures GemmKBatch in conv to GEMM transforms is never 0 (would zero the divisor in +/// integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch)). +inline constexpr index_t clamp_gemm_k_batch(index_t k_batch) noexcept +{ + return k_batch < 1 ? index_t{1} : k_batch; +} + struct DeviceProperties { DeviceProperties() @@ -33,6 +40,10 @@ inline ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index const int max_capacity = max_occupancy * device_properties.num_cu_; ck::index_t k_batch = 1; + if(grid_size <= 0) + { + return k_batch; + } const auto optimal_split = static_cast(std::floor((1.0 * max_capacity) / grid_size)); if(optimal_split > 1) diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp index 3379fb2c59..74ec0af7d5 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp @@ -21,6 +21,10 @@ template struct TransformConvBwdWeightToGemm { + // Same contract as TransformConvBwdWeightToGemmV2 (non-zero K tile factors). + static_assert(GemmK1Number > 0, "GemmK1Number must be positive"); + static_assert(K0PerBlock > 0, "K0PerBlock must be positive"); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index 94eae555e9..eeef3e736e 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -31,6 +31,11 @@ template struct TransformConvBwdWeightToGemmV2 { + // Compile-time contract: divisor GemmK1Number * K0PerBlock * GemmKBatch in + // integer_divide_ceil(GemmKTotal, ...) must stay non-zero (GemmKBatch clamped at runtime). + static_assert(GemmK1Number > 0, "GemmK1Number must be positive"); + static_assert(K0PerBlock > 0, "K0PerBlock must be positive"); + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; From 6946b07408f2114fb0966a89a1f54cdfae4031b9 Mon Sep 17 00:00:00 2001 From: Luo Cheng Date: Fri, 24 Apr 2026 06:44:37 +0800 Subject: [PATCH 5/9] [CK] Fix out of bounds modifications caused by negative topk_ids in MoeSortingMultiPhaseKernel_P0_v1 (#6242) ## Motivation Fix sglang randomly crash by filter negative topk ids. ## Technical Details In sglang expert parallel mode, there may be idle batch (batch=0) fired, it will reuse batch=1 resource in cuda graph mode. But in topk op, it will set non used topk ids to -1, in idle batch case, all topk ids are set to -1. In `MoeSortingMultiPhaseKernel_P0_v1` negative expert id will cause overwrite somewhere and sglang may randomly crash. Except idle batch case, if the captured batch sizes are discrete, there may be -1 of expert id due to the similar logic. 
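For illustration only (not part of the patch): a host-side sketch of the guard that the one-line kernel change below adds. The counting array and the id values here are made up; in the real kernel an un-filtered negative `eid` would index the sorting buffers out of bounds.

```cpp
// Sketch: expert ids of -1 coming from an idle batch must be skipped, otherwise
// expert_counts[eid] writes before the start of the array.
#include <array>
#include <cstdio>
#include <vector>

int main()
{
    constexpr int num_experts = 8;
    std::array<int, num_experts> expert_counts{};

    // Idle-batch case described above: unused topk slots are set to -1.
    const std::vector<int> topk_ids = {3, 5, -1, -1};

    for(const int eid : topk_ids)
    {
        // Guard mirrors the kernel fix: reject out-of-range AND negative ids.
        if(eid >= 0 && eid < num_experts)
        {
            ++expert_counts[eid];
        }
    }

    for(int e = 0; e < num_experts; ++e)
        std::printf("expert %d: %d tokens\n", e, expert_counts[e]);
}
```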
## Test Plan ## Test Result ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. Co-authored-by: zovonoir --- include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 07eda483d2..7d766bbe67 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -1685,7 +1685,7 @@ struct MoeSortingMultiPhaseKernel_P0_v1 IndexType eid = x[j.value]; // ext_vector_type must use int to [] uint32_t curr_token_id, curr_topk_id; kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id); - if(eid < kargs.num_experts) + if(eid < kargs.num_experts && eid >= 0) { if constexpr(Problem::LocalToken) { From 34084aeb11393f27818667b16e472e72bea20c1f Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Fri, 24 Apr 2026 07:08:41 +0800 Subject: [PATCH 6/9] [CK_TILE] fix(fmha): support >2GB KV cache in batch prefill via template dispatch (#6653) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation The CK batch prefill kernel previously failed (silent overflow + page faults) when the KV cache exceeded 2 GB, blocking long-context inference workloads (e.g., 128K+ token contexts with paged KV). Two distinct failure modes were addressed: 1. **>4GB SRD overflow (`page_size < kN0`):** The SRD `buffer_load_dwordx4` path uses a 32-bit `voffset` register; for small page sizes the rebased SRD spans the full KV pool and the offset wraps past 2 GB, corrupting K/V loads. 2. **gfx950 page-table fault (`page_size >= kN0`):** On CDNA4 the hardware validates the **full SRD `num_records` range** against page-table permissions (CDNA3 only checks per-instruction `voffset`). After per-tile SRD rebase, an un-trimmed `num_records` field extends past the live page and faults on freed/protected memory. ## Technical Details **Two-mode `tile_scatter_gather` selected by the `kUseGlobalLoad` template parameter:** | Case | `page_size` | KV cache size | Mode | Load path | Addressing | |---|---|---|---|---|---| | 1 | `>= kN0` (large pages) | any | SRD (`kUseGlobalLoad=false`) | `buffer_load_dwordx4` | 32-bit `voffset`, bounded by per-page rebase | | 2 | `< kN0` (small pages) | `<= 2 GB` | SRD (`kUseGlobalLoad=false`) | `buffer_load_dwordx4` | 32-bit `voffset`, fits in INT32 byte range | | 3 | `< kN0` (small pages) | `> 2 GB` | Global-load (`kUseGlobalLoad=true`) | `async_load_tile_raw_flat` (K) + `load_tile_flat` (V) | 64-bit | **Dispatch:** the auto-gen API layer (`fmha_batch_prefill.py`) selects the kernel instantiation at launch from `(page_block_size, num_total_pages * batch_stride_k * kElementBytes)`, so the small-page penalty is paid only when correctness requires it. **gfx950 SRD `num_records` trimming:** in the K and V rebase lambdas of `block_fmha_batch_prefill_pipeline_qr_ks_vs_async`, `set_bottom_tensor_view_buffer_size(page_stride_k/v)` is called after each rebase to constrain `num_records` to the live page. Required for CDNA4 page-table validation; harmless on CDNA3. **Pipeline sync for the global-load path:** - V uses synchronous `load_tile_flat`; K uses `async_load_tile_raw_flat`. 
- `v_physical_pages_current` is double-buffered so the V flat load doesn't race against the next iteration's K rebase computation. **Arch guards:** `global_load_lds` intrinsics are gated to `__gfx94__` / `__gfx950__` (CDNA3+). Other architectures hit a `dependent_false` static_assert with a descriptive message. **Device-side assertion convention:** SRD setters use `__builtin_assume(cond)` (hint-only) rather than ``'s `assert()`. The latter introduces an `__assert_fail` call whose register pressure scatters the K-SRD scalar register window across conditional branches, corrupting `buffer_load_dwordx4` on gfx950. ## Test Plan Tested on both MI308 (gfx942) and MI355 (gfx950) via the aiter wrapper test suite. All coverage lives in **`op_tests/test_batch_prefill.py`**: - **Functional matrix (96 cases)** — `test_batch_prefill`: `page_size ∈ {1, 16, 1024}` × `kv_layout ∈ {linear, vectorized}` × `dtype ∈ {bf16, fp8 quant variants}` × `causal` × `soft_cap` × `LSE` × `batch_size ∈ {1, 4}` (parametrized to exercise per-sequence SRD rebase across batch boundaries). - **>2 GB coverage** — `test_batch_prefill_large_kvcache`: extended to allocate a 5 GB+ KV cache pool and exercise both `kUseGlobalLoad=true` (small-page) and `kUseGlobalLoad=false` (large-page rebase) paths. Includes both single-batch and multi-batch (`batch_size=4`) cases to exercise per-sequence SRD rebase across the >2 GB pool. - Numerical reference: PyTorch SDPA, per-batch loop with `atol` / `rtol` from the existing batch prefill test harness. ## Test Result | Arch | `test_batch_prefill` | `test_batch_prefill_large_kvcache` (>2 GB) | |------|----------------------|---------------------| | MI308 (gfx942) | All passed | Passed | | MI355 (gfx950) | All passed | Passed | **Performance impact (gfx950, hot SRD path):** - +2.67% kernel-time on `seqlen=1024 / page_sz=1024 / bf16 / sglang / causal / soft_cap=30`, attributable in full to the two `set_bottom_tensor_view_buffer_size` calls in the K/V rebase lambdas (5-run median, signal/noise ≈ 9×). - This cost is **mandatory for gfx950 correctness** on >2 GB workloads — removing the setters re-introduces page-faults. - gfx942: 0 regressions in the same range (all configs ≤ +0.97%). ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
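For illustration only (not part of the patch): a condensed restatement of the dispatch rule from the table above. The real helper, `fmha_batch_prefill_select_kv_load_mode`, is added to `fmha_fwd.hpp` later in this patch; names such as `kN0` and `element_bytes` follow the PR's wording.

```cpp
// Sketch of the per-arm dispatch: GLOBAL_LOAD_LDS is chosen only when the page is
// smaller than one K/V tile AND the whole KV pool no longer fits in the SRD's
// 32-bit byte offset; everything else stays on the fast BUFFER_LOAD path.
#include <cstdint>
#include <limits>

enum class KvLoadMode { BUFFER_LOAD, GLOBAL_LOAD_LDS };

inline KvLoadMode select_kv_load_mode(std::int64_t page_block_size,
                                      std::int64_t kN0,            // K/V tile rows per block
                                      std::int64_t num_total_pages,
                                      std::int64_t batch_stride_k,  // elements per page
                                      std::int64_t element_bytes)
{
    // Promote to 64-bit before multiplying so the pool-size product itself cannot overflow.
    const std::int64_t kv_pool_bytes = num_total_pages * batch_stride_k * element_bytes;

    const bool page_smaller_than_tile = page_block_size < kN0;
    const bool pool_exceeds_int32 =
        kv_pool_bytes > static_cast<std::int64_t>(std::numeric_limits<std::int32_t>::max());

    return (page_smaller_than_tile && pool_exceeds_int32) ? KvLoadMode::GLOBAL_LOAD_LDS
                                                          : KvLoadMode::BUFFER_LOAD;
}
```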
--- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 80 +++++- example/ck_tile/01_fmha/fmha_fwd.hpp | 32 ++- .../arch/amd_buffer_addressing_builtins.hpp | 81 ++++++ .../core/tensor/tile_scatter_gather.hpp | 177 ++++++++++++- include/ck_tile/core/utility/type_traits.hpp | 14 + include/ck_tile/ops/fmha.hpp | 1 + .../block_attention_kv_load_mode_enum.hpp | 17 ++ ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 242 +++++++++++------- .../pipeline/block_fmha_pipeline_problem.hpp | 6 + .../ops/fmha/pipeline/tile_fmha_traits.hpp | 6 +- 10 files changed, 540 insertions(+), 116 deletions(-) create mode 100644 include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 7c3efb9c18..8c006c09db 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -22,8 +22,16 @@ from codegen.cpp_symbol_map import ( QSCALE_CHECK_MAP, QSCALE_MAP, ) +from codegen.arch import ArchTrait from codegen.utils import update_file +# Architecture trait for kernels requiring global_load_lds (CDNA3+). +# Only used for GLOBAL_LOAD_LDS variants; all other kernels are arch-agnostic. +CDNA3_PLUS_ARCH = ArchTrait( + "cdna3_plus", + preprocessor_check="defined(__gfx94__) || defined(__gfx950__)", +) + DTYPE_BITS = { "fp32": 32, "fp16": 16, @@ -34,6 +42,10 @@ DTYPE_BITS = { "bf8": 8, } +# Element size in bytes per dtype, used by the auto-generated dispatcher to +# decide kv_load_mode per-arm (total KV cache bytes vs INT32_MAX). +DTYPE_BYTES = {k: v // 8 for k, v in DTYPE_BITS.items()} + K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256} SUPPORTED_PAGE_SIZE = [1, 16, 1024] @@ -47,6 +59,10 @@ KV_LOOKUP_TABLE_ENUM_MAP = { "vllm": "ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D", "sglang": "ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D", } +KV_LOAD_MODE_ENUM_MAP = { + False: "ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD", + True: "ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS", +} FMHA_BATCH_PREFILL_PIPELINE_MAP = { @@ -61,6 +77,8 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT """ FMHA_FWD_KERNEL_BODY = """ +#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check}) + using fmha_dtype_{F_idx} = {F_dtype}; using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; @@ -87,7 +105,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad}, {F_sink}, {F_page_size}, {F_kv_memory_layout}, - {F_kv_lookup_table}>; + {F_kv_lookup_table}, + {F_kv_load_mode}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -125,7 +144,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, 
{F_kv_load_mode}>; #include @@ -140,10 +159,13 @@ float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_b constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} + +#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch_check}) """ FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp" FMHA_FWD_API = """ +#include #include namespace {{ @@ -194,6 +216,7 @@ float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, """ FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ + constexpr int kElementBytes = {F_element_bytes}; {F_hdim_case} }} """ @@ -203,8 +226,8 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v """ FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{ - using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size}) && (fmha_batch_prefill_select_kv_load_mode(a.page_block_size, {F_bn0}, a.num_total_pages, a.batch_stride_k, kElementBytes) == {F_kv_load_mode})) {{ + using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}, {F_kv_load_mode}>; return fmha_batch_prefill_(s, a); }} """ @@ -253,12 +276,14 @@ class FmhaFwdApiTrait: kv_memory_layout: str kv_lookup_table: str page_size: int = 1 # page block size + use_global_load: bool = False # use global_load_lds_* for >2GB KV cache @property def name(self) -> str: return ( f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.kv_memory_layout}-{self.kv_lookup_table}-ps{self.page_size}" + + ("-gload" if self.use_global_load else "-bload") ) @property @@ -481,6 +506,7 @@ class FmhaFwdApiPool: ], F_page_size=trait.page_size, F_sink=BOOL_MAP[trait.sink], + F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[trait.use_global_load], ) if_j = "if" if j == 0 else "else if" per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( @@ -488,7 +514,10 @@ class FmhaFwdApiPool: ) if_i = "if" if i == 0 else "else if" per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format( - F_if=if_i, F_dtype=dtype, 
F_hdim_case=per_hdim_case + F_if=if_i, + F_dtype=dtype, + F_element_bytes=DTYPE_BYTES[dtype], + F_hdim_case=per_hdim_case, ) if not per_dtypes: # empty string we add some ignore to suppress warning in api @@ -539,6 +568,7 @@ class FmhaFwdKernel: F_pipeline: FmhaFwdPipeline mask_impl: str F_page_size: int = 1 # page block size + F_use_global_load: bool = False # use global_load_lds_* for >2GB KV cache @property def template(self) -> str: @@ -588,6 +618,10 @@ class FmhaFwdKernel: F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag], F_page_size=self.F_page_size, F_sink=BOOL_MAP[self.F_pipeline.F_sink], + F_kv_load_mode=KV_LOAD_MODE_ENUM_MAP[self.F_use_global_load], + F_arch_check=CDNA3_PLUS_ARCH.preprocessor_check + if self.F_use_global_load + else "true", ) @property @@ -595,6 +629,7 @@ class FmhaFwdKernel: # TODO: we don't encode idx here return ( f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_ps{self.F_page_size}_" + + ("gload_" if self.F_use_global_load else "bload_") + self.F_tile.name + "_" + self.F_pipeline.name @@ -632,6 +667,7 @@ class FmhaFwdKernel: kv_memory_layout=self.F_pipeline.F_kv_memory_layout, kv_lookup_table=self.F_pipeline.F_kv_lookup_table, page_size=self.F_page_size, + use_global_load=self.F_use_global_load, ) @@ -714,8 +750,11 @@ class CustomFactory(KernelComponentFactory): def get_fwd_blobs( - kernel_filter: Optional[str], receipt, optdim_list, mask_impl, - targets: Optional[List[str]] = None + kernel_filter: Optional[str], + receipt, + optdim_list, + mask_impl, + targets: Optional[List[str]] = None, ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing # (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with @@ -837,6 +876,25 @@ def get_fwd_blobs( api_pool.register_traits(k.api_trait()) gen.append(k) + # For page_size < kN0 (tile.F_bn0), also generate a GLOBAL_LOAD_LDS + # variant for >2GB KV cache support. The default (BUFFER_LOAD) uses SRD + # buffer_load (fast, <2GB). GLOBAL_LOAD_LDS uses global_load_lds_* + # (slower, handles >2GB). 
+ if page_size < tile.F_bn0: + k_global_load = FmhaFwdKernel( + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + F_page_size=page_size, + F_use_global_load=True, + ) + api_pool.register_traits(k_global_load.api_trait()) + gen.append(k_global_load) + return (api_pool, gen) @@ -856,7 +914,9 @@ def write_blobs( optdim_list, mask_impl, ) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets) + api_pool, kernels = get_fwd_blobs( + kernel_filter, receipt, optdim_list, mask_impl, targets + ) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) @@ -871,7 +931,9 @@ def list_blobs( mask_impl, ) -> None: with file_path.open("a") as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets) + _, kernels = get_fwd_blobs( + kernel_filter, receipt, optdim_list, mask_impl, targets + ) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 6c842def58..98e2df2e1e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -673,6 +673,33 @@ struct fmha_batch_prefill_args ck_tile::index_t nhead_stride_kv_block_descale = 0; // Stride along num_kv_head dimension }; +// Selects the KV-cache load mode for a batch-prefill dispatch arm. +// GLOBAL_LOAD_LDS: required when (a) the page is smaller than one K/V tile +// so per-page SRD is impossible, AND (b) the total KV-pool byte size +// exceeds INT32_MAX so SRD's 32-bit byte offset cannot address it. +// BUFFER_LOAD: every other case — the SGPR-resident SRD path is fastest. +// Inputs are taken as plain integers so the helper has no template parameter +// and can be called from each codegen-emitted dispatcher arm with the arm's +// compile-time kN0 / element_bytes substituted as constants. +inline ck_tile::BlockAttentionKVCacheLoadModeEnum +fmha_batch_prefill_select_kv_load_mode(ck_tile::index_t page_block_size, + ck_tile::index_t kN0, + ck_tile::index_t num_total_pages, + ck_tile::index_t batch_stride_k, + ck_tile::index_t element_bytes) +{ + // Promote every operand to long_index_t so overflow is impossible regardless + // of multiplication order. A bare `static_cast(num_total_pages) + // * batch_stride_k * element_bytes` only works because of left-to-right + // associativity — a future reorder of the operands would silently truncate. + const auto kv_pool_bytes = static_cast(num_total_pages) * + static_cast(batch_stride_k) * + static_cast(element_bytes); + return (page_block_size < kN0 && kv_pool_bytes > INT32_MAX) + ? 
ck_tile::BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS + : ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD; +} + template auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) { @@ -1457,7 +1484,9 @@ template + ck_tile::BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D, + ck_tile::BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ = + ck_tile::BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD> struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_ +CK_TILE_DEVICE void +async_global_load_lds_dwordxn(void* smem, const void* global_addr, bool_constant = {}) +{ +#if !defined(__gfx94__) && !defined(__gfx950__) + static_assert(always_false_v>, + "global_load_lds requires CDNA3+ (gfx940/gfx950). " + "Ensure kKVLoadMode is BUFFER_LOAD on this architecture."); +#endif + + static_assert(num_dwords == 1 || num_dwords == 4, + "global_load_lds supports num_dwords == 1 or 4 only " + "(2 dwords does not exist on any supported arch; " + "3 dwords only on CDNA4 and unused in FMHA pipeline)"); + +// Inline asm: only the global address is an explicit operand. The LDS +// destination is implicit via M0 (see contract above). `"=r"(smem)` is a +// SSA scheduling anchor only — `smem` is NOT written by this asm; the +// load goes to LDS at `M0[17:2]*4 + offset:0 + ThreadID*4`. +#define CK_TILE_GLOBAL_LOAD_LDS_INSTR(instr) \ + if constexpr(pre_nop) \ + asm volatile("s_nop 4\n" instr " %1, off offset:0" \ + : "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \ + : "v"(global_addr) \ + : "memory" /*prevents reorder across m0_{set,inc}*/); \ + else \ + asm volatile(instr " %1, off offset:0" \ + : "=r"(smem) /*scheduling anchor; real LDS dest is M0*/ \ + : "v"(global_addr) \ + : "memory" /*prevents reorder across m0_{set,inc}*/); + + if constexpr(num_dwords == 1) + { + CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dword"); + } + else if constexpr(num_dwords == 4) + { + CK_TILE_GLOBAL_LOAD_LDS_INSTR("global_load_lds_dwordx4"); + } +#undef CK_TILE_GLOBAL_LOAD_LDS_INSTR +} + template CK_TILE_DEVICE thread_buffer diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index aa29345892..45131abb97 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -45,9 +45,29 @@ template > + typename YsGatherDims = sequence<0>, + bool kUseGlobalLoad_ = false> struct tile_scatter_gather { + static constexpr bool kUseGlobalLoad = kUseGlobalLoad_; + +#if !defined(__gfx94__) && !defined(__gfx950__) + // global_load_lds instruction is only available on CDNA3+ (gfx940/gfx950). + // On other architectures, kUseGlobalLoad must be false. + static_assert(!kUseGlobalLoad_, + "kUseGlobalLoad requires global_load_lds (CDNA3+: gfx940/gfx950). " + "This kernel should not be instantiated on this architecture."); +#endif + + // Empty placeholder used by the SRD instantiation so physical_pages_ and + // page_stride_elements_ occupy zero bytes there (combined with + // [[no_unique_address]] on the member declarations). Access sites are all + // inside `if constexpr(kUseGlobalLoad_)` arms, which compile out in SRD + // mode, so no caller needs to change. 
+ struct gl_field_empty_t + { + }; + using BottomTensorView = remove_reference_t; using WindowLengths = remove_cvref_t; using TileDstr = remove_cvref_t; @@ -233,15 +253,22 @@ struct tile_scatter_gather const BottomTensorIndex& window_origin, const TileDstr& tile_distribution, const PageIdxArray& page_idx, - const ValidArray& valids) + const ValidArray& valids, + index_t page_stride_elements = 0) : bottom_tensor_view_{bottom_tensor_view}, window_lengths_{window_lengths}, window_origin_{window_origin}, tile_dstr_{tile_distribution}, page_idx_{page_idx}, + physical_pages_{}, + page_stride_elements_{}, valids_{valids}, pre_computed_coords_{} { + if constexpr(kUseGlobalLoad_) + { + page_stride_elements_ = page_stride_elements; + } #if 0 // debug // TODO: this use more register for FA, but less register for GEMM // need investigation @@ -357,6 +384,34 @@ struct tile_scatter_gather bottom_tensor_view_.buf_.p_data_ = data; } + // Override buffer size (input in RAW elements, NOT pre-divided by PackedSize) for + // SRD num_records control. Use to set max range when SRD is rebased per-tile + // (page_size >= kN0 path): each rebased SRD only needs to cover one page; without + // this the SRD claims validity for memory beyond the allocated buffer, which can + // fault on gfx950 page-table validation. + // + // Matches buffer_view ctor convention (buffer_view.hpp:245): input is raw element + // count and is divided by PackedSize before being stored. For PackedSize=1 + // (fp16/bf16/fp8) the division is a no-op; for PackedSize=2 (FP4 / packed int4) + // skipping it would over-report num_records by 2x and silently mask OOB on SRD + // reads. batch_prefill currently does not exercise the packed-type path, but this + // setter is generic infrastructure (lives in tile_scatter_gather.hpp) so it must + // honor the same invariant the ctor enforces. + CK_TILE_DEVICE constexpr void set_bottom_tensor_view_buffer_size(index_t size) + { + // Hint the optimizer that size is positive without inserting a runtime + // branch. Using assert() here corrupted gfx950 batch_prefill + // output: the __assert_fail handler's SGPR pressure forced the K-SRD + // register window to be reused as scratch and scattered the SRD writes + // across two conditional branches, which gfx950's packed + // buffer_load_dwordx4 issue window doesn't tolerate (gfx942 absorbs it + // via per-tile single-dword loads). __builtin_assume is hint-only — + // no branch, no scratch SGPRs, no codegen impact. + __builtin_assume(size > 0); + using BufType = remove_cvref_t; + bottom_tensor_view_.buf_.buffer_size_ = size / BufType::PackedSize; + } + // move thread's window adaptor coordinate and bottom tensor coordinate // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] 
==> [offset] template @@ -458,7 +513,21 @@ struct tile_scatter_gather // read from bottom tensor const vector_t vec_value = [&]() { - if constexpr(std::is_same_v) + if constexpr(kUseGlobalLoad_) + { + // Global load mode: 64-bit typed pointer arithmetic + const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_; + const auto physical_page = physical_pages_[idx_gather]; + const auto coord_offset = bottom_tensor_thread_coord.get_offset(); + const long_index_t total_offset = + static_cast(physical_page) * page_stride_elements_ + + coord_offset + page_offset; + const auto* addr = base_ptr + total_offset; + vector_t v; + __builtin_memcpy(&v, addr, sizeof(vector_t)); + return v; + } + else if constexpr(std::is_same_v) { return get_bottom_tensor_view().template get_vectorized_elements( bottom_tensor_thread_coord, @@ -680,7 +749,23 @@ struct tile_scatter_gather const auto page_offset = page_idx_[idx_gather]; // read from bottom tensor - if constexpr(std::is_same_v) + if constexpr(kUseGlobalLoad_) + { + // Global load mode: global_load_lds with 64-bit address + constexpr index_t vector_size = + sizeof(vector_t) / sizeof(uint32_t); // dwords per vector + const auto* base_ptr = get_bottom_tensor_view().buf_.p_data_; + const auto physical_page = physical_pages_[idx_gather]; + const auto coord_offset = bottom_tensor_thread_coord.get_offset(); + const long_index_t total_offset = + static_cast(physical_page) * page_stride_elements_ + + coord_offset + page_offset; + const auto* addr = base_ptr + total_offset; + // global_load_lds takes a byte address; addr (const DataType*) + // converts implicitly to const void*, no explicit cast needed. + async_global_load_lds_dwordxn(smem, addr, pre_nop_); + } + else if constexpr(std::is_same_v) { get_bottom_tensor_view().template async_get_vectorized_elements_raw( smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_); @@ -1046,6 +1131,13 @@ struct tile_scatter_gather CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; } + CK_TILE_DEVICE void update_physical_pages(const PageIdxArray& pages) + { + static_assert(kUseGlobalLoad_, + "global-load mode only; physical_pages_ is unused in SRD mode."); + physical_pages_ = pages; + } + CK_TILE_DEVICE void update_valids(const ValidArray& new_valids) { if constexpr(std::is_same_v == false) @@ -1139,7 +1231,29 @@ struct tile_scatter_gather // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d] TileDstr tile_dstr_; + // Scatter/gather offsets for each element, set by update_page_idx(). + // SRD mode (kUseGlobalLoad=false): buffer_load(SRD, page_idx_[i] + coord). + // page_idx_[i] = within-page offset when kPageBlockSize >= kN0 (SRD rebased to page base) + // page_idx_[i] = page_base + within-page offset when kPageBlockSize < kN0 (full voffset) + // Global load mode (kUseGlobalLoad=true): page_idx_[i] = within-page offset only. + // Full address = base + physical_pages_[i] * page_stride_elements_ + page_idx_[i] + coord PageIdxArray page_idx_; + + // Physical page indices for global load mode (kUseGlobalLoad=true only). + // Maps each gather element to its physical page in a paged memory pool. + // Updated via update_physical_pages() before each load call. + // SRD mode: collapsed to gl_field_empty_t so the storage disappears. + [[no_unique_address]] std::conditional_t + physical_pages_; + + // Page stride in elements for global load mode (kUseGlobalLoad=true only). + // physical_pages_[i] * page_stride_elements_ gives the page base offset in elements. 
+ // Set at construction time via the make_tile_scatter_gather overload that + // takes bool_constant; immutable thereafter. + // SRD mode: collapsed to gl_field_empty_t so the storage disappears. + [[no_unique_address]] std::conditional_t + page_stride_elements_; + ValidArray valids_; // this contains: @@ -1178,7 +1292,8 @@ template + index_t... YsGatherDims, + bool UseGlobalLoad = false> CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(const TensorView_& tensor_view, const WindowLengths_& window_lengths, @@ -1187,7 +1302,9 @@ make_tile_scatter_gather(const TensorView_& tensor_view, const StaticPageIndexArray_& page_idx, number, number, - sequence) + sequence, + bool_constant = {}, + index_t page_stride_elements = 0) { return tile_scatter_gather, remove_cvref_t, @@ -1196,11 +1313,17 @@ make_tile_scatter_gather(const TensorView_& tensor_view, std::nullptr_t, HsGatherDim, NumCoord, - sequence>{ - tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr}; + sequence, + UseGlobalLoad>{tensor_view, + window_lengths, + origin, + tile_distribution, + page_idx, + nullptr, + page_stride_elements}; } -// Legacy overload (compatible with original API) +// Legacy overload (compatible with original API, kUseGlobalLoad=false) template +CK_TILE_DEVICE constexpr auto +make_tile_scatter_gather(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + const StaticPageIndexArray_& page_idx, + bool_constant, + index_t page_stride_elements = 0) +{ + return tile_scatter_gather, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + std::nullptr_t, + 0, + 1, + sequence<0>, + UseGlobalLoad>{tensor_view, + window_lengths, + origin, + tile_distribution, + page_idx, + nullptr, + page_stride_elements}; +} + template ` — a value-template that is always `false` but whose +// evaluation is deferred until template instantiation. The canonical use is +// inside the `else` arm of an `if constexpr` chain or under an arch-gated +// `#if` to fire a `static_assert` ONLY when the offending instantiation is +// actually requested, e.g.: +// +// if constexpr (...) { ... } +// else { static_assert(always_false_v, "unsupported T"); } +// +// A bare `static_assert(false, ...)` would fire at template-definition +// parse time on conforming compilers, breaking the whole TU. +template +inline constexpr bool always_false_v = false; + // remove_cvref_t template using remove_reference_t = typename std::remove_reference::type; diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 8a5d77bf46..59e868f678 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -3,6 +3,7 @@ #pragma once #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp new file mode 100644 index 0000000000..826cd106f1 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck_tile { + +// KV cache load addressing mode selector for batch_prefill / paged-attention pipelines. +// - BUFFER_LOAD: SGPR-based SRD via buffer_load_* (default; 32-bit byte addressing, <2GB pool) +// - GLOBAL_LOAD_LDS: direct global_load_lds_* (64-bit addressing, required for >2GB KV cache) +enum class BlockAttentionKVCacheLoadModeEnum +{ + BUFFER_LOAD = 0, + GLOBAL_LOAD_LDS = 1, +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 4f2d3d58c2..8aa6d17dc3 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" @@ -134,7 +135,8 @@ template + index_t kVectorSize, + bool kUseGlobalLoad_ = false> CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physical_pages, const index_t& stride_token, const index_t& stride_page_block, @@ -156,81 +158,65 @@ CK_TILE_HOST_DEVICE void kv_offset_array_transform(const IndexArrayType& physica const index_t& thread_coord_start = coord_vec[kCoordAxis]; constexpr index_t kInPageOffsetMask = (1 << kLog2PageSize) - 1; - if constexpr(kIsKcache) - { - // K cache: per-token lookup - // Each token may be on a different page, so we use physical_pages[k0] for each. - static_for<0, kLoopCount, 1>{}([&](auto k0) { - const index_t global_token_idx = - global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; + // Addressing strategy — four cases controlled by (kPageBlockSize vs kN0, kUseGlobalLoad_): + // + // Case 1: kPageBlockSize >= kN0 + // SRD is rebased per-tile to the page base (rebase_{k,v}_window in caller). + // Page base is absorbed into the SRD's 48-bit base pointer (SGPR-resident). + // This function writes within-page offset only. + // + // Case 2: kPageBlockSize < kN0 && kUseGlobalLoad_ + // SRD cannot be rebased (multi-page wave). Loads use global_load_lds_*; the full + // 64-bit address is computed by tile_scatter_gather::load() in + // include/ck_tile/core/tensor/tile_scatter_gather.hpp from physical_pages_ + + // page_stride_elements_. This function writes within-page offset only. + // + // Case 3: kPageBlockSize < kN0 && !kUseGlobalLoad_ (kNeedFullOffset == true) + // SRD base is the entire KV buffer; the only place to encode page identity + // is the voffset itself. This function writes the FULL offset: + // page * stride_page_block + within_page + // Limited to <2GB total KV bytes by 32-bit voffset hardware width. + // + // Case 4: kPageBlockSize >= kN0 && kUseGlobalLoad_ + // Not emitted by codegen. Backstop static_assert in + // BlockFmhaBatchPrefillPipelineQRKSVSAsync. 
+ constexpr bool kNeedFullOffset = (kPageBlockSize < kN0) && !kUseGlobalLoad_; - if constexpr(kPageBlockSize >= kN0) + static_for<0, kLoopCount, 1>{}([&](auto k0) { + const index_t global_token_idx = + global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; + const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; + + // Within-page offset (layout-dependent for V cache with VECTORIZED_LAYOUT) + const index_t within_page = [&]() { + if constexpr(!kIsKcache && kKVMemoryLayout == + BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) { - // SRD rebasing mode: within-page offset only. - // The full page base is handled by rebasing the SRD pointer. - kv_offset_vec[k0] = token_idx_in_page * stride_token; + return (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) + + (token_idx_in_page % kVectorSize); } else { - // Full global offset (original code path for ps1, ps16, etc.) - const index_t physical_page = physical_pages[k0]; - kv_offset_vec[k0] = - physical_page * stride_page_block + token_idx_in_page * stride_token; + return token_idx_in_page * stride_token; } - }); - } - else // V cache - { - // V cache: use physical_pages[k0] for each token - // physical_pages was already populated correctly by load_physical_pages(), handling: - // - page_size=1: page_idx maps token_idx -> physical_page directly - // - V tile crosses pages: per-token page lookup - // - V tile in single page: lane0 lookup with broadcast to all lanes - static_for<0, kLoopCount, 1>{}([&](auto k0) { - const index_t global_token_idx = - global_seq_offset + thread_coord_start + kLoopStart + kLoopStride * k0.value; - const index_t token_idx_in_page = global_token_idx & kInPageOffsetMask; + }(); - if constexpr(kPageBlockSize >= kN0) - { - // SRD rebasing mode: within-page offset only. - // The full page base is handled by rebasing the SRD pointer. - if constexpr(kKVMemoryLayout == - BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) - { - const index_t token_offset = - (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) + - (token_idx_in_page % kVectorSize); - kv_offset_vec[k0] = token_offset; - } - else - { - kv_offset_vec[k0] = token_idx_in_page * stride_token; - } - } - else - { - // Full global offset (original code path for ps1, ps16, etc.) - const index_t physical_page = physical_pages[k0]; - const long_index_t page_base_offset = - static_cast(physical_page) * stride_page_block; - - if constexpr(kKVMemoryLayout == - BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT) - { - const index_t token_offset = - (token_idx_in_page / kVectorSize) * (stride_token * kVectorSize) + - (token_idx_in_page % kVectorSize); - kv_offset_vec[k0] = page_base_offset + token_offset; - } - else - { - kv_offset_vec[k0] = page_base_offset + token_idx_in_page * stride_token; - } - } - }); - } + // SRD + page_size < kN0: add page base to form complete voffset for buffer_load. + // + // 32-bit by hardware: SRD buffer_load voffset is fundamentally 32-bit (CDNA3 MUBUF + // microcode format), so this branch is only reachable when total KV bytes fit in + // INT32_MAX. The kUseGlobalLoad_ template path handles the >2GB case via 64-bit + // global_load_lds_*; widening kv_offset_vec here would not lift the 2GB ceiling + // because the hardware truncates voffset regardless. 
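+        // Concrete bound (illustrative): INT32_MAX bytes is roughly 2 GiB, so with 2-byte
+        // fp16/bf16 elements this branch can address at most about 2^30 KV elements per buffer.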
+ if constexpr(kNeedFullOffset) + { + kv_offset_vec[k0] = physical_pages[k0] * stride_page_block + within_page; + } + else + { + kv_offset_vec[k0] = within_page; + } + }); } // a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) @@ -270,10 +256,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; static constexpr index_t kPageBlockSize = Problem::kPageBlockSize; static constexpr index_t kVectorSize = Problem::kVectorSize; - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - static constexpr auto I3 = number<3>{}; + // Single load-mode selector for the whole pipeline. GLOBAL_LOAD_LDS routes K/V + // tiles through global_load_lds_* (handles >2GB KV cache); BUFFER_LOAD uses SRD + // buffer_load_*. The enum is named at the trait/Problem level; internally we + // derive a bool helper to keep `if constexpr` sites narrow. Codegen only emits + // GLOBAL_LOAD_LDS arms when page_size < kN0; the static_assert is a backstop. + static constexpr auto kKVLoadMode = Problem::kKVLoadMode; + static constexpr bool kUseGlobalLoad = + (kKVLoadMode == BlockAttentionKVCacheLoadModeEnum::GLOBAL_LOAD_LDS); + static_assert(!kUseGlobalLoad || (kPageBlockSize < kN0), + "GLOBAL_LOAD_LDS load mode is only valid when kPageBlockSize < kN0; " + "codegen should not emit this instantiation otherwise."); + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); static constexpr bool kIsGroupMode = Problem::kIsGroupMode; @@ -626,19 +623,26 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kKVMemoryLayout, true, kN0, - kVectorSize>( + kVectorSize, + kUseGlobalLoad>( k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(), k_dram_block_window.get_window_lengths(), k_dram_block_window.get_window_origin(), k_dist, - k_offsets); // K DRAM tile window for + k_offsets, + bool_constant{}, + page_stride_k); + if constexpr(kUseGlobalLoad) + { + k_dram_window.update_physical_pages(k_physical_pages); + } k_dram_window.init_raw(); - // SRD rebasing: move the buffer descriptor base pointer to each page's start address - // using 48-bit pointer arithmetic, so voffset only needs the small within-page offset. - // Only applies when kPageBlockSize >= kN0 (all threads in a wave access the same page). + // SRD rebasing for K: only for page_size >= kN0 (all threads on same page). + // For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle + // addressing. auto rebase_k_window = [&](auto& window, index_t physical_page) { if constexpr(kPageBlockSize >= kN0) { @@ -649,24 +653,36 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const auto* page_ptr = base_ptr + static_cast(physical_page) * page_stride_k; window.set_bottom_tensor_view_data_ptr(page_ptr); + // Limit SRD num_records to one page worth of elements. + // Without this, the SRD claims validity for [page_ptr, page_ptr + + // full_buffer_size), which extends far beyond the allocated buffer when rebased to + // high pages. On gfx950, the hardware may validate the full SRD range against page + // table permissions, causing faults on freed/protected memory beyond the buffer. 
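+                // Note: page_stride_k is an element count (it is added directly to the typed
+                // base_ptr above), so the capped SRD describes exactly one K page.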
+ window.set_bottom_tensor_view_buffer_size(page_stride_k); window.init_raw(); } }; + // SRD rebasing for V: only for page_size >= kN0 (all threads on same page). + // For page_size < kN0: either flat loads (kUseGlobalLoad) or full offsets handle + // addressing. auto rebase_v_window = [&](auto& window, index_t physical_page) { if constexpr(kPageBlockSize >= kN0) { + // readfirstlane: make physical_page provably wave-uniform so the + // resulting SRD lands in SGPRs (required by buffer load instructions). physical_page = __builtin_amdgcn_readfirstlane(physical_page); const auto* base_ptr = v_dram_block_window_tmp.get_bottom_tensor_view().buf_.p_data_; const auto* page_ptr = base_ptr + static_cast(physical_page) * page_stride_v; window.set_bottom_tensor_view_data_ptr(page_ptr); + window.set_bottom_tensor_view_buffer_size(page_stride_v); window.init_raw(); } }; - // Initial K SRD rebase + // Initial K SRD rebase (no-op for page_size < kN0, uses flat loads instead) rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]); constexpr auto k_oob_ck = bool_constant{}; @@ -874,12 +890,13 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kKVMemoryLayout, false, kN0, - kVectorSize>(v_physical_pages_k2, - stride_v, - page_stride_v, - v_coord, - v_offsets_k2, - current_seq_k); + kVectorSize, + kUseGlobalLoad>(v_physical_pages_k2, + stride_v, + page_stride_v, + v_coord, + v_offsets_k2, + current_seq_k); static_for<0, V_KIterInner, 1>{}([&](auto k1) { constexpr auto idx = number{}; @@ -899,9 +916,20 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kKVMemoryLayout, false, kN0, - kVectorSize>( + kVectorSize, + kUseGlobalLoad>( v_physical_pages, stride_v, page_stride_v, v_coord, v_offsets, current_seq_k); } + + // v_offsets semantics — see the four-case addressing-strategy block above + // kNeedFullOffset in kv_offset_array_transform. Three cases reach this lambda: + // Case 1 (kPageBlockSize >= kN0): within-page offset; page base in SRD. + // Case 2 (page_size < kN0, kUseGlobalLoad): within-page offset; page base computed + // by tile_scatter_gather::load() from + // physical_pages_. + // Case 3 (page_size < kN0, !kUseGlobalLoad, == kNeedFullOffset): + // FULL offset (page * stride + within), + // carried in the 32-bit voffset (<2GB cap). }; // Prefetch V physical pages early to hide buffer load latency @@ -915,11 +943,32 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_offsets, number<1>{}, // HsGatherDim number<1>{}, // NumCoord - VPageIndexYDims); + VPageIndexYDims, + bool_constant{}, + page_stride_v); + if constexpr(kUseGlobalLoad) + { + v_dram_window.update_physical_pages(v_physical_pages); + } - // Initial V SRD rebase + // Initial V SRD rebase. Single source of truth: rebase_v_window's own + // `if constexpr(kPageBlockSize >= kN0)` makes this a no-op for case 2/3. + // Do not re-add an outer guard here — it would duplicate the inner check + // and drift if the lambda's gating condition ever changes. rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); + // Save the *current* tile's V physical pages into v_dram_window before + // prefetch_v_physical_pages overwrites the v_physical_pages buffer with the + // *next* tile's pages. Case-2 only (kUseGlobalLoad); case-1/3 don't read + // physical_pages_ from the window. Encapsulating the save+prefetch pair + // here makes the ordering invariant unmissable when a fourth prefetch site + // is added later. 
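+        // Callers pass the same kK1-multiple K positions that prefetch_v_physical_pages()
+        // itself takes; the wrapper only fixes the order of the page save relative to the
+        // prefetch.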
+ auto save_and_prefetch_v_pages = [&](auto k_loop_start) { + if constexpr(kUseGlobalLoad) + v_dram_window.update_physical_pages(v_physical_pages); + prefetch_v_physical_pages(k_loop_start); + }; + // prefetch K tile async_load_tile_raw( k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np); @@ -972,7 +1021,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync } // Prefetch V physical pages early - overlaps with GEMM0 computation - prefetch_v_physical_pages(number{}); + save_and_prefetch_v_pages(number{}); // STAGE 1, QK gemm clear_tile(s_acc); // initialize C @@ -1166,7 +1215,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Prefetch V physical pages early - overlaps with softmax computation if constexpr(k1_loops > 1) { - prefetch_v_physical_pages(number<2 * kK1>{}); + save_and_prefetch_v_pages(number<2 * kK1>{}); } auto m_local = block_tile_reduce( @@ -1220,8 +1269,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync v_dram_window, {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... - v_buf = load_tile( - v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); update_v_offsets(number<2 * kK1>{}); v_dram_window.update_page_idx(v_offsets); rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]); @@ -1390,8 +1438,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) { - v_buf = load_tile( - v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); // Update V offsets using previously prefetched physical pages update_v_offsets(number<(2 + i_k1.value) * kK1>{}); v_dram_window.update_page_idx(v_offsets); @@ -1401,7 +1448,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync // Prefetch V physical pages for NEXT iteration - overlaps with GEMM1 if constexpr(i_k1 + 1 < k1_loops - 1) { - prefetch_v_physical_pages(number<(2 + i_k1.value + 1) * kK1>{}); + save_and_prefetch_v_pages(number<(2 + i_k1.value + 1) * kK1>{}); } block_sync_lds(); @@ -1481,9 +1528,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync kKVMemoryLayout, true, kN0, - kVectorSize>( + kVectorSize, + kUseGlobalLoad>( k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k); k_dram_window.update_page_idx(k_offsets); + if constexpr(kUseGlobalLoad) + k_dram_window.update_physical_pages(k_physical_pages); rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]); // After sink→window transition (i_total_loops == num_sink_loop), V window diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 87db7b85b9..c441f57c86 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -117,6 +117,12 @@ struct BlockFmhaBatchPrefillPipelineProblem static_assert((kPageBlockSize & (kPageBlockSize - 1)) == 0, "kPageBlockSize must be power of two"); + // KV cache load addressing mode. GLOBAL_LOAD_LDS handles >2GB pools via + // 64-bit addressing; BUFFER_LOAD (default) uses SRD buffer_load for the + // <2GB fast path. The 2GB bound = INT32_MAX byte offset, matching CK's + // existing TwoGB convention. 
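+    // Plumbing: the mode originates from the kKVLoadMode_ parameter of
+    // TileFmhaBatchPrefillTraits (tile_fmha_traits.hpp, default BUFFER_LOAD) and is consumed
+    // by BlockFmhaBatchPrefillPipelineQRKSVSAsync, which derives its kUseGlobalLoad flag from it.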
+ static constexpr auto kKVLoadMode = Traits_::kKVLoadMode; + static constexpr index_t kVectorSize = 16 / sizeof(KDataType_); // Dwordx4 static constexpr auto kKVMemoryLayout = Traits_::kKVMemoryLayout; static constexpr auto kKVLookupTable = Traits_::kKVLookupTable; diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 7df39c3d11..e7370cdb65 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_attention_kv_load_mode_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_kvcache_layout_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp" #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" @@ -58,7 +59,9 @@ template + BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D, + BlockAttentionKVCacheLoadModeEnum kKVLoadMode_ = + BlockAttentionKVCacheLoadModeEnum::BUFFER_LOAD> struct TileFmhaBatchPrefillTraits : public TileFmhaTraits Date: Fri, 24 Apr 2026 16:22:28 +0800 Subject: [PATCH 7/9] [CK] Fix CI Failures for PR From Forks (#6701) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Fork PRs fail CI when `RUN_AITER_TESTS` or `RUN_FA_TESTS` is enabled. The docker scripts run `git clone -b "$CK_*_BRANCH" https://github.com/ROCm/rocm-libraries.git`, but a fork's branch doesn't exist upstream: ``` fatal: Remote branch not found in upstream origin ``` Example: [PR #6529 build #4](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-6529/4/pipeline). ## Technical Details **`Jenkinsfile`** — for PRs, use the upstream-visible PR ref instead of the head branch name: ```groovy CURRENT_BRANCH_NAME = env.CHANGE_ID ? "refs/pull/${env.CHANGE_ID}/head" : (env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME) ``` **`Dockerfile.aiter` / `Dockerfile.fa`** — `git clone -b ` only accepts branches (`refs/heads/*`) and tags (`refs/tags/*`), so it can't resolve `refs/pull/N/head`. Switch to `git fetch`, which accepts any refspec (and still works for plain branch names): ```sh mkdir rocm-libraries && cd rocm-libraries git init -q git remote add origin https://github.com/ROCm/rocm-libraries.git git fetch --depth 1 --filter=blob:none origin "$CK_*_BRANCH" git sparse-checkout init --cone git sparse-checkout set projects/composablekernel git checkout FETCH_HEAD ``` `git checkout FETCH_HEAD` lands in detached HEAD, which breaks the existing `git branch -m "$CK_*_BRANCH"` (and that name isn't a valid local branch anyway). Decouple the local branch name from the upstream ref: - Replace `git init` + `git branch -m` with `git init -b "$LOCAL_BRANCH"` (requires git ≥ 2.28, satisfied by base images) - `LOCAL_BRANCH="ck-import-${ROCM_LIBRARIES_SHA}"` in the rocm-libraries path; `LOCAL_BRANCH="$CK_*_BRANCH"` in the fallback - Downstream `git clone -b ... ../ck` uses `$LOCAL_BRANCH` ## Test Plan Manually trigger a build on this PR with `RUN_AITER_TESTS=true` and `RUN_FA_TESTS=true`; both docker images should build end-to-end. 
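A quick local sanity check (illustrative; `6701` stands in for any open PR number) shows why the PR ref is fetchable from upstream while a fork's branch name is not:

```sh
# The fork's head branch is not published on the upstream remote, so this prints nothing:
git ls-remote https://github.com/ROCm/rocm-libraries.git "refs/heads/<fork-branch-name>"
# GitHub publishes the PR ref on the upstream remote, so this resolves to the PR head SHA:
git ls-remote https://github.com/ROCm/rocm-libraries.git "refs/pull/6701/head"
# and that is exactly the refspec the updated Dockerfiles hand to `git fetch`.
```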
## Test Result [jenkins / rocm-libraries-folder/Composable Kernel / PR-6701 / #3](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-6701/3/pipeline/) ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- Dockerfile.aiter | 18 +++++++++++------- Dockerfile.fa | 18 +++++++++++------- Jenkinsfile | 2 +- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/Dockerfile.aiter b/Dockerfile.aiter index 8d6e995656..4fcebc9033 100644 --- a/Dockerfile.aiter +++ b/Dockerfile.aiter @@ -10,27 +10,31 @@ RUN pip install pandas zmq einops ninja tabulate vcs_versioning && \ sudo mkdir /home/jenkins/workspace && \ cd /home/jenkins/workspace && rm -rf rocm-libraries ck && \ if [ "$CK_FROM_ROCM_LIBRARIES" = "1" ]; then \ - git clone --depth 1 -b "$CK_AITER_BRANCH" --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && \ - cd rocm-libraries && \ + mkdir rocm-libraries && cd rocm-libraries && \ + git init -q && \ + git remote add origin https://github.com/ROCm/rocm-libraries.git && \ + git fetch --depth 1 --filter=blob:none origin "$CK_AITER_BRANCH" && \ git sparse-checkout init --cone && \ git sparse-checkout set projects/composablekernel && \ - git checkout "$CK_AITER_BRANCH" && \ + git checkout FETCH_HEAD && \ ROCM_LIBRARIES_SHA=$(git rev-parse --short HEAD) && \ + LOCAL_BRANCH="ck-import-${ROCM_LIBRARIES_SHA}" && \ mv projects/composablekernel ../ck && \ cd ../ck && rm -rf ../rocm-libraries && \ - git init && \ + git init -b "$LOCAL_BRANCH" && \ git config user.name "assistant-librarian[bot]" && \ git config user.email "assistant-librarian[bot]@users.noreply.github.com" && \ - git branch -m "$CK_AITER_BRANCH" && git add -A && \ + git add -A && \ git commit -m "import from ROCm/rocm-libraries@$ROCM_LIBRARIES_SHA" ; \ else \ - git clone --depth 1 -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git ck ; \ + git clone --depth 1 -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git ck && \ + LOCAL_BRANCH="$CK_AITER_BRANCH" ; \ fi && \ cd /home/jenkins/workspace && rm -rf aiter && \ git clone --depth 1 -b "$AITER_BRANCH" --recursive https://github.com/ROCm/aiter.git && \ cd aiter && \ rm -rf 3rdparty/composable_kernel/ && \ - git clone -b "$CK_AITER_BRANCH" ../ck 3rdparty/composable_kernel/ && \ + git clone -b "$LOCAL_BRANCH" ../ck 3rdparty/composable_kernel/ && \ python3 setup.py develop && \ groupadd -g 1001 jenkins && \ useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \ diff --git a/Dockerfile.fa b/Dockerfile.fa index 47643310bd..025bbd414e 100644 --- a/Dockerfile.fa +++ b/Dockerfile.fa @@ -12,27 +12,31 @@ RUN set -x ; \ sudo mkdir /home/jenkins/workspace && \ cd /home/jenkins/workspace && rm -rf rocm-libraries ck && \ if [ "$CK_FROM_ROCM_LIBRARIES" = "1" ]; then \ - git clone --depth 1 -b "$CK_FA_BRANCH" --no-checkout --filter=blob:none https://github.com/$CK_FA_ORIGIN/rocm-libraries.git && \ - cd rocm-libraries && \ + mkdir rocm-libraries && cd rocm-libraries && \ + git init -q && \ + git remote add origin https://github.com/$CK_FA_ORIGIN/rocm-libraries.git && \ + git fetch --depth 1 --filter=blob:none origin "$CK_FA_BRANCH" && \ git sparse-checkout init --cone && \ git sparse-checkout set projects/composablekernel && \ - git checkout "$CK_FA_BRANCH" && \ + git checkout FETCH_HEAD && \ ROCM_LIBRARIES_SHA=$(git rev-parse --short HEAD) && \ + LOCAL_BRANCH="ck-import-${ROCM_LIBRARIES_SHA}" && \ mv 
projects/composablekernel ../ck && \ cd ../ck && rm -rf ../rocm-libraries && \ - git init && \ + git init -b "$LOCAL_BRANCH" && \ git config user.name "assistant-librarian[bot]" && \ git config user.email "assistant-librarian[bot]@users.noreply.github.com" && \ - git branch -m "$CK_FA_BRANCH" && git add -A && \ + git add -A && \ git commit -m "import from ROCm/rocm-libraries@$ROCM_LIBRARIES_SHA" > /dev/null ; \ else \ - git clone --depth 1 -b "$CK_FA_BRANCH" https://github.com/$CK_FA_ORIGIN/composable_kernel.git ck ; \ + git clone --depth 1 -b "$CK_FA_BRANCH" https://github.com/$CK_FA_ORIGIN/composable_kernel.git ck && \ + LOCAL_BRANCH="$CK_FA_BRANCH" ; \ fi && \ cd /home/jenkins/workspace && rm -rf flash-attention && \ git clone --depth 1 -b "$FA_BRANCH" --recursive "https://github.com/$FA_ORIGIN/flash-attention.git" && \ cd flash-attention && \ rm -rf csrc/composable_kernel/ && \ - git clone -b "$CK_FA_BRANCH" ../ck csrc/composable_kernel/ && git add csrc/composable_kernel && \ + git clone -b "$LOCAL_BRANCH" ../ck csrc/composable_kernel/ && git add csrc/composable_kernel && \ MAX_JOBS=$(nproc) GPU_ARCHS="$GPU_ARCHS" /opt/venv/bin/python3 -u -m pip install --no-build-isolation -v . && \ groupadd -g 1001 jenkins && \ useradd -u 1001 -g 1001 -m -s /bin/bash jenkins && \ diff --git a/Jenkinsfile b/Jenkinsfile index 170e0bf432..05eb7f97ef 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1191,7 +1191,7 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_ 0 13 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true 0 11 * * * % RUN_FULL_CONV_TILE_TESTS=true;RUN_AITER_TESTS=true;RUN_FA_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true 0 9 * * * % RUN_PYTORCH_TESTS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX101=false;BUILD_GFX103=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false;FORCE_CI=true''' : "" -CURRENT_BRANCH_NAME = env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME +CURRENT_BRANCH_NAME = env.CHANGE_ID ? "refs/pull/${env.CHANGE_ID}/head" : (env.CHANGE_BRANCH ? env.CHANGE_BRANCH : env.BRANCH_NAME) POLL_SPEC = BRANCH_NAME == "develop" ? 'H H/6 * * *' : '' From c92fd392993b9cbf7565a82c4254b4087807c7c8 Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Sat, 25 Apr 2026 00:30:41 +0800 Subject: [PATCH 8/9] Improve the performance of qr_ks_vs_whole_k_prefetch pipeline (#6209) ## About qr_ks_vs_whole_k_prefetch pipeline This PR updates and enhances the qr_ks_vs_whole_k_prefetch pipeline to improve performance on both MI350 GPUs through better MFMA instruction usage, transposed V-loading support, and N0-loop implementation. The pipeline targets scenarios where the number of workgroups is low, enabling better CU occupancy by using smaller MTile sizes (kM0=64 vs 128) while prefetching entire K tiles. ## Changes: - Adds transposed V-loading support (qr_ks_vs_whole_k_prefetch_trload) to avoid using shuffle instructions on MI350 - Implements N0-loop based Gemm0 to reduce tile window movement overhead and eliminate `clear_tile` calls - Adds full support for hdim96/hdim160 without padding requirements - Updates MFMA instruction selection to ensure optimal choices for MI350 ## Performance results 1. 
For attention shapes which lead to kM0=64, `qr_ks_vs_async_whole_k_prefetch_trload` shows much better performance than `qr_ks_vs_async_trload` on the same case (execution time `41.02ms` by whole_k_prefetch_trload & `58.50ms` by qr_ks_vs_async_trload), and `qr_ks_vs_async_whole_k_prefetch_trload` also shows clearly better performance than the recently tuned `qr_ks_vs_async` on the same case (execution time `41.02ms` by whole_k_prefetch_trload & `47.60ms` by qr_ks_vs_async)
2. Also on MI300, for attention shapes which lead to kM0=64, `qr_ks_vs_async_whole_k_prefetch` shows much better performance than `qr_ks_vs_async` (which is supposed to be very efficient) on the same case (execution time `64.50ms` by whole_k_prefetch & `80.20ms` by qr_ks_vs_async)
3. For attention shapes which lead to kM0=128, `qr_ks_vs_async_whole_k_prefetch_trload` shows slightly better performance than `qr_ks_vs_async` on MI350 (execution time `104.50ms` by whole_k_prefetch_trload & `106.50ms` by qr_ks_vs_async). And they show completely on-par performance on MI300.

## Test/Verify

1. Use the ROCM xformers branch `test_whole_k_prefetch_n0loop` to test/verify the qr_ks_vs_whole_k_prefetch pipeline, since this pipeline cannot be used by the ck_tile fmha example so far
2. Use the following command line for building/testing xformers
>```bash
> #> git clone -b test_whole_k_prefetch_n0loop https://github.com/ROCm/xformers
> #> git submodule update --init --recursive
> #> pip install --no-build-isolation -e ./
> #> pytest tests/test_mem_eff_attention.py::test_forward
>```
3. Any scripts which can run on xformers can be used to evaluate the qr_ks_vs_whole_k_prefetch pipeline. Use the following two environment variables to switch between the pipelines
> ```bash
> #> export FMHA_DISABLE_SPECIAL_TREATMENT=1 #> to disable using FAV3 and qr_ks_vs_async_trload pipeline
> #> export FMHA_ENABLE_ASYNC_PIPELINE=1 #> to disable using qr_ks_vs_async pipeline for comparing
> ```

## Discussion

---------

Co-authored-by: Po Yen Chen
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: poyenc <1132573+poyenc@users.noreply.github.com>
Co-authored-by: qianfengz <12429178+qianfengz@users.noreply.github.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
---
 include/ck_tile/host/rotating_buffers.hpp | 1 +
 include/ck_tile/ops/fmha.hpp | 1 +
 .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 189 +++-
 .../pipeline/block_fmha_pipeline_problem.hpp | 46 +
 ...mha_pipeline_qr_ks_vs_whole_k_prefetch.hpp | 1000 ++++++++---------
 ..._ks_vs_whole_k_prefetch_default_policy.hpp | 911 ++++++++++-----
 ...eline_qr_ks_vs_whole_k_prefetch_trload.hpp | 861 ++++++++++++++
 .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 5 +-
 .../ops/fmha/pipeline/tile_fmha_shape.hpp | 2 +-
 ...ock_gemm_areg_bsmem_creg_v2_prefetch_k.hpp | 268 +++++
 ...ock_gemm_areg_bsmem_creg_v2_prefetch_n.hpp | 239 ++++
 ...m_areg_bsmem_trload_creg_v2_prefetch_n.hpp | 243 ++++
 12 files changed, 2921 insertions(+), 845 deletions(-)
 create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp
 create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp
 create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp
 create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp

diff --git a/include/ck_tile/host/rotating_buffers.hpp
b/include/ck_tile/host/rotating_buffers.hpp index baec4b45e8..32745ee424 100644 --- a/include/ck_tile/host/rotating_buffers.hpp +++ b/include/ck_tile/host/rotating_buffers.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/host/hip_check_error.hpp" #include +#include namespace ck_tile { diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 59e868f678..cf651312d9 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -56,6 +56,7 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index b04205f2c2..b7dcdb3648 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -32,6 +32,83 @@ namespace ck_tile { +namespace detail { + +// A helper struct for detecting n0loop +template +struct has_n0loop_flag : std::false_type +{ +}; + +template +struct has_n0loop_flag< + T, + std::enable_if_t && T::kUseN0Loop>> + : std::true_type +{ +}; + +template +static inline constexpr bool is_n0loop_pipeline_v = has_n0loop_flag::value; + +// A helper struct for detecting ignore_fast_exp2 flag +template +struct has_ignore_fast_exp2_flag : std::false_type +{ +}; + +// IgnoreFastExp2 is used by some pipeline which explicitly chooses not to use FAST_EXP2; +// By detecting the kIgnoreFastExp2 from the pipeline, the kernel's MakeKargsImpl() interface +// is able to avoid passing an in-correct scale_s parameter to the kernel layer +template +struct has_ignore_fast_exp2_flag< + T, + std::enable_if_t && + T::kIgnoreFastExp2>> : std::true_type +{ +}; + +template +static inline constexpr bool ignore_fast_exp2_v = has_ignore_fast_exp2_flag::value; + +// A helper struct for detecting naive_hdim_load, naive_hdim_load means load tiles of +// hdim96/hdim160/hdim192 without padding the tensor_view/tile_window to hdim128/hdim256 +// naive_hdim_load is current supported by the qr_ks_vs_whole_k_prefetch_pipeline +template +struct has_naive_hdim_load_flag : std::false_type +{ +}; + +template +struct has_naive_hdim_load_flag< + T, + std::enable_if_t && + T::kIsNaiveHDimLoad>> : std::true_type +{ +}; + +template +static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag::value; + +// A helper struct for detecting kUseTrLoad +template +struct has_use_trload_flag : std::false_type +{ +}; + +template +struct has_use_trload_flag< + T, + std::enable_if_t && T::kUseTrLoad>> + : std::true_type +{ +}; + +template +static inline constexpr bool is_using_trload_v = has_use_trload_flag::value; + +} // namespace detail + template struct FmhaFwdKernel { @@ -77,13 +154,14 @@ struct FmhaFwdKernel static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; + static constexpr bool kUseTrLoad = detail::is_using_trload_v; - static constexpr bool kUseTrLoad = 
FmhaPipeline::Problem::kUseTrLoad; #if defined(__gfx950__) static constexpr bool kIsAvailable = true; #else static constexpr bool kIsAvailable = !kUseTrLoad; #endif + static constexpr std::string_view kPipelineName = FmhaPipeline::name; template // to avoid duplicated base class prblem, introduce an template @@ -444,7 +522,9 @@ struct FmhaFwdKernel num_head_q, nhead_ratio_qk, #if CK_TILE_FMHA_FWD_FAST_EXP2 - static_cast(scale_s * ck_tile::log2e_v<>), + detail::ignore_fast_exp2_v + ? scale_s + : static_cast(scale_s * ck_tile::log2e_v<>), #else scale_s, #endif @@ -897,7 +977,9 @@ struct FmhaFwdKernel num_head_q, nhead_ratio_qk, #if CK_TILE_FMHA_FWD_FAST_EXP2 - static_cast(scale_s * ck_tile::log2e_v<>), + detail::ignore_fast_exp2_v + ? scale_s + : static_cast(scale_s * ck_tile::log2e_v<>), #else scale_s, #endif @@ -1039,6 +1121,7 @@ struct FmhaFwdKernel const void* seqlen_k_ptr, const void* block_scale_seqstart_q_ptr, const void* block_scale_seqstart_k_ptr, + const void* seqstart_v_scale_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -1097,6 +1180,7 @@ struct FmhaFwdKernel seqlen_k_ptr, block_scale_seqstart_q_ptr, block_scale_seqstart_k_ptr, + seqstart_v_scale_ptr, hdim_q, hdim_v, num_head_q, @@ -1158,6 +1242,7 @@ struct FmhaFwdKernel const void* seqlen_k_ptr, const void* block_scale_seqstart_q_ptr, const void* block_scale_seqstart_k_ptr, + const void* seqstart_v_scale_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, @@ -1216,6 +1301,7 @@ struct FmhaFwdKernel seqlen_k_ptr, block_scale_seqstart_q_ptr, block_scale_seqstart_k_ptr, + seqstart_v_scale_ptr, hdim_q, hdim_v, num_head_q, @@ -1602,6 +1688,10 @@ struct FmhaFwdKernel static_cast(i_nhead) * kargs.nhead_stride_o + batch_offset_o; + constexpr index_t kQKHeaddimToUse = detail::is_naive_hdim_load_v + ? FmhaPipeline::kQKHeaddim + : FmhaPipeline::kSubQKHeaddim; + // Q/K/V DRAM and DRAM window const auto q_dram = [&]() { const auto q_dram_naive = make_naive_tensor_view( @@ -1612,10 +1702,10 @@ struct FmhaFwdKernel number<1>{}); if constexpr(FmhaPipeline::kQLoadOnce) { - return pad_tensor_view(q_dram_naive, - make_tuple(number{}, - number{}), - sequence{}); + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); } else { @@ -1634,10 +1724,21 @@ struct FmhaFwdKernel number<1>{}); constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; - return pad_tensor_view( - k_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + + if constexpr(detail::is_n0loop_pipeline_v) + { + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } }(); const auto v_dram = [&]() { if constexpr(std::is_same_v) @@ -1649,18 +1750,29 @@ struct FmhaFwdKernel number{}, number<1>{}); - const auto v_dram_transposed = transform_tensor_view( - v_dram_naive, - make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen_k)), - make_tuple(sequence<1>{}, sequence<0>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + if constexpr(!kUseTrLoad) + { + const auto v_dram_transposed = transform_tensor_view( + v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? 
kPadSeqLenK : false; - return pad_tensor_view( - v_dram_transposed, - make_tuple(number{}, number{}), - sequence{}); + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false; + + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }; } else { @@ -1683,17 +1795,28 @@ struct FmhaFwdKernel q_dram, [&]() { if constexpr(FmhaPipeline::kQLoadOnce) - return make_tuple(number{}, - number{}); + return make_tuple(number{}, number{}); else return make_tuple(number{}, number{}); }(), {i_m0, 0}); - auto k_dram_window = make_tile_window( - k_dram, - make_tuple(number{}, number{}), - {0, 0}); + auto k_dram_window = [&]() { + if constexpr(detail::is_n0loop_pipeline_v) + { + return make_tile_window( + k_dram, + make_tuple(number{}, number{}), + {0, 0}); + } + else + { + return make_tile_window( + k_dram, + make_tuple(number{}, number{}), + {0, 0}); + } + }(); auto v_dram_window = make_tile_window( v_dram, @@ -1843,7 +1966,10 @@ struct FmhaFwdKernel *(reinterpret_cast(kargs.alibi_slope_ptr) + i_batch_ * kargs.alibi_slope_stride + i_nhead_); #if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; + if constexpr(!detail::ignore_fast_exp2_v) + { + slope *= ck_tile::log2e_v<>; + } #endif if constexpr(kHasMask) { @@ -2826,7 +2952,10 @@ struct FmhaFwdKernel *(reinterpret_cast(kargs.alibi_slope_ptr) + i_batch_ * kargs.alibi_slope_stride + i_nhead_); #if CK_TILE_FMHA_FWD_FAST_EXP2 - slope *= ck_tile::log2e_v<>; + if constexpr(!detail::ignore_fast_exp2_v) + { + slope *= ck_tile::log2e_v<>; + } #endif if constexpr(kHasMask) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index c441f57c86..a8a8f96d3b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -9,6 +9,52 @@ namespace ck_tile { +namespace detail { + +template +CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize() +{ + if constexpr(std::is_same_v || std::is_same_v) + { + // ToDo: need support in ck_tile for using buffer_load_dwordx3 + // if constexpr(ElemPerThread % 6 == 0) + // return 6; + if constexpr(ElemPerThread % 8 == 0) + return 8; + else if constexpr(ElemPerThread % 4 == 0) + return 4; + else if constexpr(ElemPerThread % 2 == 0) + return 2; + return 1; + } + else if constexpr(std::is_same_v) + { + // ToDo: need support in ck_tile for using buffer_load_dwordx3 + // if constexpr(ElemPerThread % 3 == 0) + // return 3; + if constexpr(ElemPerThread % 4 == 0) + return 4; + else if constexpr(ElemPerThread % 2 == 0) + return 2; + return 1; + } + else + return 1; +}; + +template +CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize() +{ + constexpr index_t ElemPerThread = (kHigherDimSize * kLowerDimSize) / kThreadBlockSize; + + return GetMaxVectorSize(); +} + +} // namespace detail + template ; using VDataType = remove_cvref_t; using SaccDataType = remove_cvref_t; - using SMPLComputeDataType = remove_cvref_t; + using CompDataType = remove_cvref_t; using BiasDataType = remove_cvref_t; using RandValOutputDataType = remove_cvref_t; using LSEDataType = remove_cvref_t; @@ -34,12 +34,22 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch using VLayout = remove_cvref_t; static constexpr bool kQLoadOnce = true; static_assert(kQLoadOnce == Policy::QLoadOnce); + 
static_assert(!Problem::kUseTrLoad, "This pipeline does not use trload!"); + static_assert(sizeof(KDataType) == sizeof(VDataType) && + alignof(KDataType) == alignof(VDataType), + "K and V share the same LDS region; their element types must have identical " + "size and alignment."); + + static constexpr bool kUseN0Loop = true; + static constexpr bool kIgnoreFastExp2 = true; + static constexpr bool kIsNaiveHDimLoad = true; static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kN0Sub = + BlockFmhaShape::kK0; // subdivision of kN0 used in N0-loop, same value as kK0 static constexpr index_t kN1 = BlockFmhaShape::kN1; static constexpr index_t kK1 = BlockFmhaShape::kK1; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; @@ -47,35 +57,33 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + // since this pipeline is only used by the inference path of xformers, the Dropout function is + // not well tested with the pipeline, so here we have Dropout disabled + static_assert(kHasDropout == false, "Dropout is not supported by this pipeline at present!"); + // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this static constexpr index_t kAlignmentQ = kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); static constexpr index_t kAlignmentK = kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); - static constexpr index_t kAlignmentV = []() { - if constexpr(std::is_same_v) - return Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); - else - return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); - }(); + static constexpr index_t kAlignmentV = + Problem::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); static constexpr index_t kAlignmentO = kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); static constexpr index_t kAlignmentBias = kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); - static constexpr index_t kAlignmentRandVal = - kPadSeqLenK ? 
1 : Policy::template GetAlignmentRandVal(); static constexpr index_t kBlockPerCu = []() { if constexpr(Problem::kBlockPerCu != -1) @@ -135,9 +143,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch typename AttentionVariantParams, typename BlockIndices> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile const KElementFunction& k_element_func, const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VElementFunction& v_element_func, @@ -158,8 +166,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch void* smem_ptr, DropoutType& dropout) const { - ignore = q_element_func; - ignore = k_element_func; + // xformers path does not require the pipeline to output random values for host + // verification, since a separate kernel is used to generate random values + ignore = randval_dram_block_window_tmp; static_assert( std::is_same_v> && @@ -168,8 +177,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch "wrong!"); static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN0Sub == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && @@ -177,24 +186,51 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch "wrong!"); constexpr auto I0 = number<0>{}; - constexpr auto I1 = number<1>{}; - constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t n0_loops = kN0 / kN0Sub; constexpr index_t k1_loops = kN0 / kK1; - static_assert(2 <= k0_loops); - static_assert(2 <= k1_loops); + + // usually kN0 is 128, kN0Sub/kK1 is 32/16 + static_assert(n0_loops >= 2, "n0_loops >= 2 required to use this pipeline"); + static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline"); + + constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); + + constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV(); + static_assert(n0_loops >= NumPrefetchV, "Check failed!"); + static_assert(k1_loops >= NumPrefetchV, "Check failed!"); constexpr bool kPreloadWholeNextIterationK = Policy::template IsPreloadWholeNextIterationK(); - constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers(); - constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers(); - constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV(); + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); - static_assert(NumKLdsBuffers >= 2); + // SaccBlockTile size is [kM0, kK1] + // PcompBlockTile size is [kM0, kN0] + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); + + SaccBlockTileType sacc_tile; + PcompBlockTileType pcomp_tile; + + // reduction function for softmax + const 
auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + using MLBlockTileType = decltype(block_tile_reduce( + PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0})); + + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + OaccBlockTileType o_acc; auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), - q_dram_block_window_tmp.get_window_lengths(), + make_tuple(number{}, number{}), q_dram_block_window_tmp.get_window_origin(), Policy::template MakeQRegTileDistribution()); @@ -202,32 +238,38 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); - auto k_dram_block_window = - make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), - k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}); + if(seqlen_k_end <= seqlen_k_start) + { + clear_tile(o_acc); + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + return o_acc; + }; auto k_dram_window = - make_tile_window(k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {seqlen_k_start, 0}, Policy::template MakeKDramTileDistribution()); using k_tile_type = decltype(load_tile(k_dram_window)); + // only prefetch two k tiles to save vgprs consumption auto k_tiles = [&]() { if constexpr(kPreloadWholeNextIterationK) - return statically_indexed_array{}; + return statically_indexed_array{}; else return statically_indexed_array{}; }(); k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); + move_tile_window(k_dram_window, {kN0Sub, 0}); auto q_tile = load_tile(q_dram_window); - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(0x00000001); + + // provide partition_index for LDS tile window with so that warp_id is in vgpr + array partition_index{get_warp_id(), get_lane_id()}; // K tile in LDS KDataType* k_lds_ptr = static_cast(smem_ptr); @@ -236,612 +278,461 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch auto k_lds_window = make_tile_window( k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); - using k_lds_window_type = - decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); + using k_lds_window_type = decltype(get_slice_tile( + k_lds_window, sequence<0, 0>{}, sequence{})); - statically_indexed_array k_lds_windows; + statically_indexed_array k_lds_windows; - static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) { - k_lds_windows[i_buf] = get_slice_tile( - k_lds_window, sequence{}, sequence<(i_buf + 1) * kN0, kK0>{}); + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + k_lds_windows[i_buf] = get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); }); - auto v_dram_window = - make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), - v_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_k_start}, // TODO: hdim split? 
- Policy::template MakeVDramTileDistribution()); // V tile in LDS auto v_lds = make_tensor_view( - reinterpret_cast(static_cast(smem_ptr) + - Policy::template GetExclusiveKLdsBytes()), + reinterpret_cast(smem_ptr), Policy::template MakeVLdsBlockDescriptor()); auto v_lds_window = make_tile_window( v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); - using v_tile_type = decltype(load_tile(v_dram_window)); - - statically_indexed_array v_tiles; - using v_lds_window_type = decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence{})); - statically_indexed_array v_lds_windows; + statically_indexed_array v_lds_windows; - static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) { + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { v_lds_windows[i_buf] = get_slice_tile( v_lds_window, sequence{}, sequence<(i_buf + 1) * kN1, kK1>{}); }); - // Block GEMM - constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); - constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {0, seqlen_k_start}, + Policy::template MakeVDramTileDistribution()); - using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); - auto s_acc = SaccBlockTileType{}; - - // reduction function for softmax - const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; - const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; - - // infer Sacc, S, P, M, L, Oacc type - using SBlockTileType = decltype(cast_tile(s_acc)); - - using MLBlockTileType = decltype(block_tile_reduce( - SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); - - using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); - - // init Oacc, M, L - auto o_acc = OaccBlockTileType{}; - auto m = MLBlockTileType{}; - auto l = MLBlockTileType{}; + const auto f_exp = [&](CompDataType x) { + if constexpr(std::is_same_v) + { + return __expf(x); + } + else + { + return exp(x); + } + }; clear_tile(o_acc); - set_tile(m, -numeric::infinity()); + set_tile(m, -numeric::infinity()); clear_tile(l); - const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); - - // check early exit if no work to do - if constexpr(FmhaMask::IsMasking || kPadSeqLenK) - { - if(num_total_loop <= 0) - { - if constexpr(kStoreLSE) - { - auto lse = - make_static_distributed_tensor(m.get_tile_distribution()); - - set_tile(lse, -numeric::infinity()); - - store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); - } - - // Note: here occ are all cleard, return it - // Note: q loaded but no fence, ignore it. 
- return o_acc; - } - } - const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), - bias_dram_block_window_tmp.get_window_lengths(), - {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N - Policy::template MakeBiasDramTileDistribution()); + make_tuple(number{}, number{}), + {bias_origin.at(number<0>{}), seqlen_k_start}, + Policy::template MakeBiasDramTileDistribution()); - auto randval_dram_window = dropout.template MakeRandvalDramWindow( - randval_dram_block_window_tmp, seqlen_k_start); + // assuming no random values need be saved, this is true when the pipeline is called from + // xformers, since we have a separate kernel to generated random values + auto null_randval_window = [&]() { + if constexpr(kHasDropout) + { + // need to pass a null_randval_dram and tile window to the BlockDropout operator to + // make it works + const auto null_randval_dram = [&]() { + const auto null_dram_naive = make_naive_tensor_view( + static_cast(nullptr), + make_tuple(1, 1), + make_tuple(1, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(null_dram_naive, + make_tuple(number<1>{}, number<1>{}), + sequence{}); + }(); + + return make_tile_window( + null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0}); + } + else + return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); + }(); q_tile = tile_elementwise_in(q_element_func, q_tile); - index_t i_total_loops = 0; + auto seqlen_k_curr = seqlen_k_start; + + using v_tile_type = decltype(load_tile(v_dram_window)); + + statically_indexed_array v_tiles; do { - if constexpr(kPreloadWholeNextIterationK) + // STAGE 1, Gemm_0 ( S = Q@K ) + if constexpr(kPreloadWholeNextIterationK) // used when kM0 = 64 { - if(i_total_loops == 0) // executed by fist iteration + if(seqlen_k_curr == seqlen_k_start) // at first iteration { - if(num_total_loop > 1) // there are multiple iterations + if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - store_tile( - k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); - k_tiles[number{}] = load_tile(k_dram_window); - if constexpr(i_k0 < k0_loops - 2) - move_tile_window(k_dram_window, {0, kK0}); + if constexpr(i_n0 < n0_loops - 1) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; - if constexpr(i_k0 == 0) - clear_tile(s_acc); + if constexpr(i_n0 == n0_loops - 1) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + + // prefetch all k_tiles for next iteration + static_for<0, n0_loops, 1>{}([&](auto ii_n0) { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }); + }; block_sync_lds(); - // execute current unroll of gemm_0 - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_windows[number{}]); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); - - store_tile( - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], - tile_elementwise_in(k_element_func, 
k_tiles[number{}])); - - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - - move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0}); - - // prefetch all k_tiles for next iteration - static_for<0, k0_loops, 1>{}([&](auto i_k0) { - k_tiles[number{}] = load_tile(k_dram_window); - - if constexpr(i_k0 < k0_loops - 1) - move_tile_window(k_dram_window, {0, kK0}); - }); - - move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0}); - - block_sync_lds(); - // execute last unroll of gemm_0 - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); } - else // there is only single iteration + else // the iteration is also the last iteration { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - store_tile( - k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); - k_tiles[number{}] = load_tile(k_dram_window); - if constexpr(i_k0 < k0_loops - 2) - move_tile_window(k_dram_window, {0, kK0}); + if constexpr(i_n0 < n0_loops - 1) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; - if constexpr(i_k0 == 0) - clear_tile(s_acc); + if constexpr(i_n0 == n0_loops - 1) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; block_sync_lds(); - // execute current unroll of gemm_0 - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_windows[number{}]); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); - - store_tile( - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); - - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); - - // move_tile_window(k_dram_window, {0, -k0_loops * kK0}); - } + }; } - else // executed by intermediate and last iteration + else // at intermediate and last iteration { - if(i_total_loops < num_total_loop - 1) // intermediate iteration + if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration { - store_tile(k_lds_windows[I0], - tile_elementwise_in(k_element_func, k_tiles[I0])); + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); + if constexpr(i_n0 == 0) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; - clear_tile(s_acc); - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, sequence<0, 0>{}, sequence{}), - k_lds_windows[I0]); - - store_tile(k_lds_windows[I1], - tile_elementwise_in(k_element_func, k_tiles[I1])); - - move_tile_window(k_dram_window, {kN0, 0}); - - // prefetch first k_tile for next iteration - 
k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); - - k_tiles[I1] = load_tile(k_dram_window); - if constexpr(1 < k0_loops - 1) - move_tile_window(k_dram_window, {0, kK0}); - - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, sequence<0, kK0>{}, sequence{}), - k_lds_windows[I1]); - - // during the gemm-loop, also prefetch other k_tiles for next iteration - static_for<2, k0_loops, 1>{}([&](auto i_k0) { - store_tile(k_lds_windows[number{}], - k_tiles[number{}]); - - k_tiles[number{}] = load_tile(k_dram_window); - if constexpr(i_k0 < k0_loops - 1) - move_tile_window(k_dram_window, {0, kK0}); + // prefetch k_tile for next iteration + k_tiles[i_n0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_windows[number{}]); - }); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); - move_tile_window(k_dram_window, {0, -(k0_loops - 1) * kK0}); + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); } else // last iteration { - store_tile(k_lds_windows[I0], - tile_elementwise_in(k_element_func, k_tiles[I0])); + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - - clear_tile(s_acc); - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, sequence<0, 0>{}, sequence{}), - k_lds_windows[I0]); - - static_for<1, k0_loops, 1>{}([&](auto i_k0) { - store_tile( - k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); + if constexpr(i_n0 == 0) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - k_lds_windows[number{}]); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); }; - }; + } } - else // only preload one unroll of K for next iteration + else // only preload one unroll of K for next iteration, used when kM0=128 { - static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { - store_tile(k_lds_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[I0])); - if constexpr(i_k0 == 0) - clear_tile(s_acc); + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[I0]), + partition_index); - if constexpr(i_k0 < k0_loops - 1) + __builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(i_n0 < n0_loops - 1) + { k_tiles[I0] = load_tile(k_dram_window); - if constexpr(i_k0 < k0_loops - 2) - move_tile_window(k_dram_window, {0, kK0}); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + + if constexpr(i_n0 == n0_loops - 1) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; + + __builtin_amdgcn_sched_barrier(0x00000001); block_sync_lds(); - // execute current unroll of gemm_0 - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, i_k0 * kK0>{}, - sequence{}), - 
k_lds_windows[number{}]); + + gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); }); + } - store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], - tile_elementwise_in(k_element_func, k_tiles[I0])); - - // prefetch first v_tile - v_tiles[I0] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - - block_sync_lds(); - gemm_0(s_acc, - get_slice_tile(q_tile, - sequence<0, (k0_loops - 1) * kK0>{}, - sequence{}), - k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); - }; - - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(0x00000001); const auto bias_tile = load_tile(bias_dram_window); // load bias tile - static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) { - v_tiles[i_buf] = load_tile(v_dram_window); - move_tile_window(v_dram_window, {0, kK1}); - }); - // STAGE 2, scale_s, add bias, mask, softmax if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); + tile_elementwise_inout( - [&](auto& x, const auto& y) { -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - x += type_convert(bias_element_func(y)); -#else - x += log2e_v * - type_convert(bias_element_func(y)); -#endif + [&](auto& x, const auto y) { + x += type_convert(bias_element_func(y)); }, - s_acc, + pcomp_tile, bias_tile); } else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) { - const auto k_origin = k_dram_block_window.get_window_origin(); - constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); - sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + constexpr auto pcomp_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(pcomp_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(pcomp_spans[number<1>{}], [&](auto idx1) { const auto tile_idx = get_x_indices_from_distributed_indices( - s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); constexpr auto i_j_idx = make_tuple(idx0, idx1); - s_acc(i_j_idx) *= scale_s; - position_encoding.update(s_acc(i_j_idx), row, col); + pcomp_tile(i_j_idx) *= scale_s; + position_encoding.update(pcomp_tile(i_j_idx), row, col); }); }); } else { - s_acc = tile_elementwise_in(s_acc_element_func, s_acc); -#if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); -#endif + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { - const auto k_origin = k_dram_block_window.get_window_origin(); - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}), - number{}, - number{}); + bool need_perpixel_check = mask.IsEdgeTile( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{}); if(need_perpixel_check) { - set_tile_if( 
- s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); - }); + set_tile_if(pcomp_tile, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); } } - const auto s = cast_tile(s_acc); // S{j} - auto m_local = block_tile_reduce( - s, - sequence<1>{}, - f_max, - -numeric::infinity()); // m_local = rowmax(S{j}) - block_tile_reduce_sync(m_local, f_max, bool_constant{}); + __builtin_amdgcn_sched_barrier(0x00000001); - const auto m_old = m; // m{j-1} - tile_elementwise_inout( - [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + const auto m_old = m; - auto p_compute = make_static_distributed_tensor( - s.get_tile_distribution()); // Pcompute{j} - - static const auto get_validated_m = [](SMPLComputeDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - FmhaMask::IsMasking) - { - return raw_m == -numeric::infinity() - ? type_convert(0.f) - : raw_m; - } - else - { - return raw_m; - } - }; - - constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); - sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - auto row_max = scale_s * get_validated_m(m[i_idx]); -#endif - sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); - } - else - { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); - } -#else - p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); -#endif - }); - }); - - auto rowsum_p = block_tile_reduce( - p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - - block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); - // l{j}, Oacc{j} - constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - const auto tmp = [&]() { - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); - } - else - { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); - } - }(); -#else - const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); -#endif - l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - // FIXME: this use different equation from FA v2 paper, - // but produce correc result. - // Is the equation wrong? 
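For reference, both the removed exp2-based code above and the f_exp-based replacement implement the same streaming ("online") softmax recurrence: m_new = max(m_old, rowmax(S_j)), P_j = exp(S_j - m_new), l_new = exp(m_old - m_new) * l_old + rowsum(P_j), and o_acc is rescaled by exp(m_old - m_new) before accumulating P_j @ V_j. The 1/l normalization is applied only once in the epilogue (o_acc *= 1.0f / l[i_idx] further below), which is why the per-iteration update can look different from the per-step form in the FA v2 paper while still producing the correct result. A minimal scalar sketch of the recurrence, with illustrative names and a single row instead of distributed tiles, is:

    // Standalone scalar sketch of the streaming-softmax recurrence
    // (names and shapes are illustrative only, not taken from the pipeline).
    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <limits>
    #include <vector>

    // Process one row of scores block-by-block, keeping a running max (m),
    // running denominator (l) and an un-normalized output accumulator (o_acc).
    float streaming_softmax_dot(const std::vector<float>& scores,
                                const std::vector<float>& values,
                                std::size_t block_size)
    {
        float m     = -std::numeric_limits<float>::infinity();
        float l     = 0.f;
        float o_acc = 0.f;

        for(std::size_t start = 0; start < scores.size(); start += block_size)
        {
            const std::size_t end = std::min(scores.size(), start + block_size);

            // m_new = max(m_old, rowmax(S_j))
            const float m_old = m;
            for(std::size_t i = start; i < end; ++i)
                m = std::max(m, scores[i]);

            // alpha rescales the history accumulated with the old running max
            const float alpha = (m_old == -std::numeric_limits<float>::infinity())
                                    ? 0.f
                                    : std::exp(m_old - m);
            l *= alpha;
            o_acc *= alpha;

            // P_j = exp(S_j - m_new); l += rowsum(P_j); o_acc += P_j * V_j
            for(std::size_t i = start; i < end; ++i)
            {
                const float p = std::exp(scores[i] - m);
                l += p;
                o_acc += p * values[i];
            }
        }

        // 1/l is applied once at the very end, mirroring the pipeline epilogue
        return l == 0.f ? 0.f : o_acc / l;
    }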
- o_acc(i_j_idx) *= tmp; - }); - }); - - if constexpr(kHasDropout) - { - auto randval_ptr = - reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeK(); - dropout.template Run( - smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); - } - - __builtin_amdgcn_sched_barrier(0x7f); - - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_tiles[I0]); - - store_tile( - v_lds_windows[I0], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch - } - else - { - store_tile(v_lds_windows[I0], - tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch - } + block_tile_reduce(m, pcomp_tile, sequence<1>{}, f_max); + block_tile_reduce_sync(m, f_max, bool_constant{}); __builtin_amdgcn_sched_barrier(0); - const auto p = - cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + auto v_shuffled_tile = make_static_distributed_tensor( + Policy::template MakeShuffledVRegTileDistribution()); + shuffle_tile(v_shuffled_tile, tile_elementwise_in(v_element_func, v_tiles[I0])); - if constexpr(!kPreloadWholeNextIterationK) + // check whether first V-LdsBufer overlap with last K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) { - if(i_total_loops < num_total_loop - 1) - { - move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0}); - k_tiles[I0] = load_tile(k_dram_window); - move_tile_window(k_dram_window, {0, kK0}); - }; - - __builtin_amdgcn_sched_barrier(0); - } - - // STAGE 3, KV gemm - if constexpr(k1_loops > 1) - { - if constexpr(NumPrefetchV == 1) // NumVLdsBuffers == 2 - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - v_tiles[I0] = load_tile(v_dram_window); - - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_windows[number{}]); - - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_tiles[I0]); - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); - } - else - { - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_tiles[I0])); - } - - move_tile_window(v_dram_window, {0, kK1}); - }); - } - else // NumVLdsBuffers == 3 or 2 - { - static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - if constexpr(i_k1 < k1_loops - NumPrefetchV) - v_tiles[number{}] = load_tile(v_dram_window); - - block_sync_lds(); - gemm_1(o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_windows[number{}]); - - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, - v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]); - store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); - } - else - { - store_tile( - v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], - tile_elementwise_in(v_element_func, - v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); - } - - if constexpr(i_k1 < k1_loops - NumPrefetchV) - move_tile_window(v_dram_window, {0, kK1}); - }); - } - } - // move K tile windows - move_tile_window(k_dram_block_window, {kN0, 0}); - - 
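The replacement loop below rotates K sub-tiles and shuffled V tiles through one shared ring of GetNumKVLdsBuffers() == 4 LDS buffers: K stage i_n0 occupies buffer i_n0 % NumKVLdsBuffers, while V stage i_k1 of gemm_1 is read from buffer (i_k1 + 2) % NumKVLdsBuffers and the next V tile is written to (i_k1 + 3) % NumKVLdsBuffers; the two extra __builtin_amdgcn_s_barrier() calls guard the only index collisions that can occur. The following standalone sketch only evaluates that index arithmetic for the "usual" shapes mentioned in the pipeline comments (kN0 = 128, kN0Sub = 32, kK1 = 16, i.e. n0_loops = 4 and k1_loops = 8, assumed here purely for illustration):

    // Minimal sketch of the shared K/V LDS ring-buffer rotation; the loop
    // counts are assumptions taken from the "usually kN0 is 128,
    // kN0Sub/kK1 is 32/16" comment, not a definitive configuration.
    #include <cstdio>

    int main()
    {
        constexpr int NumKVLdsBuffers = 4; // GetNumKVLdsBuffers()
        constexpr int n0_loops        = 4; // kN0 / kN0Sub
        constexpr int k1_loops        = 8; // kN0 / kK1

        // K sub-tile i_n0 of gemm_0 lives in buffer i_n0 % NumKVLdsBuffers
        for(int i_n0 = 0; i_n0 < n0_loops; ++i_n0)
            std::printf("K stage %d -> LDS buffer %d\n", i_n0, i_n0 % NumKVLdsBuffers);

        // V tile i_k1 of gemm_1 is read from buffer (i_k1 + 2) % NumKVLdsBuffers,
        // and the next V tile is shuffled into (i_k1 + 3) % NumKVLdsBuffers
        for(int i_k1 = 0; i_k1 < k1_loops; ++i_k1)
            std::printf("V stage %d -> LDS buffer %d\n", i_k1, (i_k1 + 2) % NumKVLdsBuffers);

        // The two extra s_barrier calls cover the only possible overlaps:
        const bool first_v_hits_last_k =
            ((n0_loops - 1) % NumKVLdsBuffers) == (2 % NumKVLdsBuffers);
        const bool last_v_hits_first_k = ((k1_loops - 1 + 2) % NumKVLdsBuffers) == 0;
        std::printf("extra barrier before first V store: %d\n", first_v_hits_last_k);
        std::printf("extra barrier after last V gemm:    %d\n", last_v_hits_first_k);
        return 0;
    }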
block_sync_lds(); - gemm_1(o_acc, - get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), - v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]); - - if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer()) - { - __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); }; - } while(++i_total_loops < num_total_loop); + store_tile( + v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile, partition_index); + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) { + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }); + + __builtin_amdgcn_sched_barrier(0); + + constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = type_convert(0.0f); + }); + } + else + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]); + }); + } + }); + + auto rowsum_p = + block_tile_reduce(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0}); + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + + // adjust o_acc[] according to the update between m and m_old + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + l(i_idx) = rowsum_p[i_idx]; + } + else + { + const auto tmp = f_exp(m_old[i_idx] - m[i_idx]); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + } + }); + + __builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(kHasDropout) + { + auto randval_lds_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + + dropout.template Run( + randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); + } + + seqlen_k_curr += kN0; + + __builtin_amdgcn_sched_barrier(0x00000001); + + auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // STAGE 3, Gemm_1 ( O = P@V ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + if constexpr(i_k1 < k1_loops - NumPrefetchV) + { + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }; + + if constexpr(i_k1 == k1_loops - NumPrefetchV) + { + if constexpr(!kPreloadWholeNextIterationK) + { + if(seqlen_k_curr < seqlen_k_end) + { + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + } + }; + + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); + + if constexpr(i_k1 < k1_loops - 1) + { + shuffle_tile(v_shuffled_tile, + tile_elementwise_in(v_element_func, + v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); + store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}], + v_shuffled_tile, + partition_index); + }; + }); + + // check whether last V-LdsBuffer overlap with first K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops 
- 1 + 2) % NumKVLdsBuffers == 0) + { + __builtin_amdgcn_s_barrier(); + }; + } while(seqlen_k_curr < seqlen_k_end); // store lse if constexpr(kStoreLSE) @@ -851,19 +742,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); -#if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || - BiasEnum == BlockAttentionBiasEnum::ALIBI) - { - lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); - } - else - { - lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); - } -#else - lse(i_idx) = m_[i_idx] + log(l_[i_idx]); -#endif + lse(i_idx) = m_[i_idx] + log(l_[i_idx]); }); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); @@ -874,17 +753,13 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); - const auto tmp = [&]() { - if constexpr(FmhaMask::IsMasking) - { - return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; - } - else - return 1 / l[i_idx]; - }(); sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) *= tmp; + + if(m[i_idx] == -numeric::infinity()) + o_acc(i_j_idx) = 0.0f; + else + o_acc(i_j_idx) *= 1.0f / l[i_idx]; }); }); @@ -916,8 +791,11 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch const AttentionVariantParams& variant_params, const BlockIndices& block_indices, void* smem_ptr, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { + ignore = sink_v; + return operator()(q_dram_block_window_tmp, identity{}, k_dram_block_window_tmp, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp index 3f015a1c1a..e5e9e2333a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp @@ -4,17 +4,20 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp" namespace ck_tile { struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy - : BlockFmhaPipelineQXKSVSCustomPolicy { - static constexpr index_t NumPrefetchV = 2; + static constexpr bool QLoadOnce = true; // needed by the kernel + static constexpr bool AsyncCopy = false; // needed by the kernel template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsPreloadWholeNextIterationK() @@ -23,30 +26,38 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy }; template - CK_TILE_DEVICE static constexpr auto GetNumKLdsBuffers() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumPrefetchV() { - return 2; - } + constexpr index_t n0_loops = 
Problem::BlockFmhaShape::kN0 / Problem::BlockFmhaShape::kK0; + constexpr index_t k1_loops = Problem::BlockFmhaShape::kN0 / Problem::BlockFmhaShape::kK1; - template - CK_TILE_DEVICE static constexpr auto GetNumPrefetchV() - { - using BlockFmhaShape = remove_cvref_t; - - constexpr index_t kN0 = BlockFmhaShape::kN0; - constexpr index_t kK1 = BlockFmhaShape::kK1; - - constexpr index_t k1_loops = kN0 / kK1; - - return min(NumPrefetchV, k1_loops); - } - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumVLdsBuffers() - { - return 2; + if constexpr(Problem::kUseTrLoad) + { + // kM0 is 64, kN0 is 128, prefetch all k_tiles + if constexpr(IsPreloadWholeNextIterationK()) + { + if constexpr(n0_loops >= 4 && k1_loops >= 6) + return 2; + return 2; + } + else // kM0 is 128, kN0 is 64, prefetch one k_tile + { + // kN0 == 64, try to prefetch more v_tiles + return 2; + }; + } + else + { + return 2; + }; }; + template + CK_TILE_HOST_DEVICE static constexpr auto GetNumKVLdsBuffers() + { + return 4; + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() { @@ -57,195 +68,537 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy Problem::BlockFmhaShape::kQKHeaddim>(); } + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetQKWarpGemmKPerThreadSize() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::kKPerThread; + }; + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetKVWarpGemmKPerThreadSize() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::kKPerThread; + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + return BlockGemm::template MakeCBlockTile() + .get_tile_distribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::Impl::kCM1PerLane; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() + { + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return min(MaxVectorSize, WG::kK / WG::WarpGemmAttribute::Impl::kABKLane); + } + template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() { - using KDataType = remove_cvref_t; - return 8 / sizeof(KDataType); + if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + return 8; + else + return 4; } + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() + { + using KDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + + return detail:: + GetDramTileAccessMaxVectorSize(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() + { + if constexpr(GetKVWarpGemmKPerThreadSize() >= 8) + return 8; + else + return 4; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() + { + 
using VDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + // special consideration when shuffling is required before storing V to LDS + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + constexpr index_t kMaxVecLoad = detail:: + GetDramTileAccessMaxVectorSize(); + constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); + + // try to avoid writing sub-dword to LDS due to poor performance + constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad) + ? kMaxVecLoad + : (ElemPerThread / kMinVecLoad); + + return kVecLoad; + } + else + { + return detail:: + GetDramTileAccessMaxVectorSize(); + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + return WG::WarpGemmAttribute::Impl::kCM1PerLane; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPack = GetSmemKPackK(); + constexpr index_t kKVector = GetAlignmentK(); + + // for hdim96 and hdim160 + if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim) + { + return kKPerBlock * kNPerBlock; + } + else if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + { + static_assert(kKVector == kKPack); + + return kKPerBlock * kNPerBlock; + } + else + { + static_assert(kKVector % kKPack == 0); + + return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; + }; + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize() + { + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize(); + + return N0 * (N1 * kKPerBlock + kKPack); + } + else + { + return kNPerBlock * kKPerBlock; + }; + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize() + { + return max(GetKSingleSmemElementSpaceSize(), + GetVSingleSmemElementSpaceSize()); + }; + template CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() { - constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers(); - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); constexpr index_t kKVector = GetAlignmentK(); - static_assert(kKVector % kKPack == 0); + // for hdim96 and hdim160, use simplest layout + if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim) + { + constexpr index_t KSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock; - constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}), - make_tuple(number{}, - number{}, - number{}, - number{}, - number<1>{}), - 
number{}, - number<1>{}); + static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); - constexpr auto k_lds_block_desc = transform_tensor_descriptor( - k_lds_block_desc_0, - make_tuple( - make_merge_transform(make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, - number{}, - number{}))), - make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); - return k_lds_block_desc; + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + else if constexpr(GetQKWarpGemmKPerThreadSize() >= 8) + { + static_assert(kKVector == kKPack); + + using KDataType = remove_cvref_t; + + constexpr index_t DataTypeSize = sizeof(KDataType); + +#ifdef __gfx950__ + // 256 contiguous bytes mapped to 64 banks with each bank 4 contiguous bytes + constexpr auto NLdsLayer = + (64 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (64 * 4 / kKPerBlock / DataTypeSize); +#else + // 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes + constexpr auto NLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); +#endif + + constexpr auto k_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + constexpr auto k_lds_block_desc_k0_nldslayer_n_k1 = transform_tensor_descriptor( + k_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{})); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_k0_nldslayer_n_k1, + make_tuple( + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<1, 3>{}, sequence<0, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + else + { + static_assert(kKVector % kKPack == 0); + + constexpr index_t KSingleSmemElementSpaceSize = + kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; + + static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize()); + + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + + constexpr auto k_lds_block_desc_0 = + 
make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + }; } template CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() { - using KDataType = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); + constexpr index_t kKVector = GetAlignmentK(); + constexpr index_t OtherK = kKPerBlock / kKVector; - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - static_assert(0 < ElemPerThread); - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - - constexpr index_t KPerThread = kMaxVecLoad; - constexpr index_t KThreads = kKPerBlock / KPerThread; - constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() - { - using VDataType = remove_cvref_t; - - constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers(); - - constexpr index_t Banks = get_n_lds_banks(); - constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType); - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(PixelsPerRow % kKPack == 0); - constexpr index_t NPerRow = PixelsPerRow / kKPack; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - static_assert(kNPerBlock % NPerRow == 0); - static_assert(kKPerBlock % kKPack == 0); - - constexpr index_t VSingleSmemElementSpaceSize = - (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); - - constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}), - make_tuple(number{}, - number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, - number{}, - number{}, - number<1>{}), - number{}, - number<1>{}); - - constexpr auto v_lds_block_desc = transform_tensor_descriptor( - v_lds_block_desc_0, - make_tuple( - make_merge_transform(make_tuple( - number{}, number{}, number{})), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return v_lds_block_desc; - } - - template - CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() - { - using VLayout = remove_cvref_t; - - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = 
Problem::BlockFmhaShape::kK1; - - if constexpr(std::is_same_v) + if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim) + // for kKPerBlock=32,64,128,256 { - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; // P + static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!"); - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(ElemPerThread % N1 == 0); - constexpr index_t K3 = ElemPerThread / N1; - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; - if constexpr(get_warp_size() % (K2 * N0) == 0) - { - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); - static_assert(kKPerBlock == K0 * K1 * K2 * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } - else - { - constexpr index_t K1 = (K2 * N0) / get_warp_size(); - constexpr index_t K2_m = K2 / K1; - constexpr index_t K0 = kBlockSize / get_warp_size() / K1; - static_assert(kKPerBlock == K0 * K1 * K2_m * K3); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<0, 2>>, - sequence<2, 1>, - sequence<3, 1>>{}); - } - } - else - { - constexpr index_t K1 = GetAlignmentV(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; - constexpr index_t N1 = kBlockSize / get_warp_size(); - static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); - static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); - constexpr index_t N0 = kNPerBlock / (N2 * N1); - static_assert(N0 != 0); + constexpr index_t KPerThread = kKVector; + constexpr index_t KThreads = OtherK; + + constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, + sequence>, tuple, sequence<1, 2>>, tuple, sequence<2, 0>>, sequence<1, 2>, sequence<0, 1>>{}); } + else // for kKPerBlock=96,160 + { + static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!"); + + constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 
3 : 5; + constexpr index_t KThreads = OtherK / KRepPerThread; + + constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}); + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() + { + constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers(); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + // K2 is the vector size for storing shuffled tile to LDS + constexpr index_t K2 = ElemPerThread / N1; + + // GetSmemKPackV() is the vector size for loading from LDS by BlockGemm + constexpr index_t kKPack = GetSmemKPackV(); + + static_assert(kKPack >= K2, "Check failed!"); + + constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack); + + static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); + + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple( + number{}, number{}, number{}, number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + else + { + constexpr index_t kKPack = GetSmemKPackV(); + + constexpr auto XorGroupSize = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); + + constexpr index_t VSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock; + + static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); + + constexpr auto v_lds_block_desc_naive = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor( + v_lds_block_desc_naive, + make_tuple(make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + return transform_tensor_descriptor( + v_lds_block_desc_permuted, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + }; + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = 
Problem::BlockFmhaShape::kK1; + + if constexpr(!Problem::kUseTrLoad) + { + constexpr index_t NPerThread = GetAlignmentV(); + constexpr index_t NThreads = kNPerBlock / NPerThread; + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + constexpr index_t KPerThread = ElemPerThread / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<2, 1>, + sequence<2, 1>>{}); + } + else + { + constexpr index_t NPerThread = GetAlignmentV(); + constexpr index_t NThreads = kNPerBlock / NPerThread; + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + constexpr index_t KPerThread = ElemPerThread / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution() + { + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + constexpr index_t NPerThread = GetAlignmentV(); + constexpr index_t NThreads = kNPerBlock / NPerThread; + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + + constexpr index_t KPerThread = ElemPerThread / NPerThread; + constexpr index_t KThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<1, 2>>{}); } template @@ -257,113 +610,167 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy typename Problem::SaccDataType, Problem::kNumGemm0Warps * get_warp_size(), TileGemmShape, + Problem::BlockFmhaShape::kK0, + Problem::BlockFmhaShape::kQKHeaddim>, typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; - constexpr auto warp_gemm = []() { - if constexpr(get_warp_size() == 64 && - std::is_same_v && - std::is_same_v && + auto warp_gemm = [&]() { + if constexpr((std::is_same_v || + std::is_same_v) && std::is_same_v) { - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32); - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32); - static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32); + constexpr index_t WarpGemmM = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + constexpr index_t WarpGemmK = + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}); + +#ifdef __gfx950__ + static_assert((WarpGemmM == 16 && WarpGemmK == 32) || + (WarpGemmM == 32 && WarpGemmK == 16), + "Not supported WarpGemm sizes!"); +#else + static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) || + (WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)), + "Not supported WarpGemm sizes!"); +#endif - // TODO: hard coded here. 
Otherwise, it produces incorrect results - constexpr index_t swizzle_factor = 4; - return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution< - swizzle_factor>{}; - } - else - { - constexpr bool SwizzleA = - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32; return WarpGemmDispatcher{}), Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), - true, // TransposeC - SwizzleA>{}; + true, + false, + false, + WGAttrNumAccessEnum::Single>{}; + } + else + { + static_assert(false, "Not supported data types!"); } }(); + using WarpGemm = remove_cvref_t; + using BlockGemmPolicy = BlockGemmARegBSmemCRegV2CustomPolicy; + WarpGemm>; if constexpr(1 < Problem::kNumGemm0Warps) - return BlockGemmARegBSmemCRegV2{}; + return BlockGemmARegBSmemCRegV2PrefetchK{}; else return BlockGemmARegBSmemCRegOneWarpV1{}; } - // leave some exclusive space so that the second v_lds buffer will nenver overlap with the first - // k_lds bufffer template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes() + CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() { - constexpr index_t single_k_lds_buffer_size = - GetSmemSizeK() / GetNumKLdsBuffers(); - constexpr index_t single_v_lds_buffer_size = - GetSmemSizeV() / GetNumVLdsBuffers(); + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>>; - if constexpr(single_k_lds_buffer_size <= single_v_lds_buffer_size) - return 0; + auto warp_gemm = [&]() { + if constexpr((std::is_same_v || + std::is_same_v) && + std::is_same_v) + { + constexpr index_t WarpGemmM = + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); + constexpr index_t WarpGemmK = + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}); + + static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) || + (WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)), + "Not supported WarpGemm sizes!"); + + if constexpr((WarpGemmM == 16 && WarpGemmK == 32) || + (WarpGemmM == 32 && WarpGemmK == 16)) + return WarpGemmDispatcher< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Double>{}; + else + return WarpGemmDispatcher< + typename Problem::PDataType, + typename Problem::VDataType, + typename Problem::OaccDataType, + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), + Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), + true, + false, + false, + WGAttrNumAccessEnum::Single>{}; + } + else + { + static_assert(false, "Not supported data types!"); + } + }(); + + using WarpGemm = remove_cvref_t; + + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV2CustomPolicy; + + if constexpr(1 < Problem::kNumGemm1Warps) + { + if constexpr(!Problem::kUseTrLoad) + return BlockGemmARegBSmemCRegV2PrefetchN{}; + else + return BlockGemmARegBSmemTrLoadCRegV2PrefetchN{}; + } else - return integer_least_multiple(single_k_lds_buffer_size - single_v_lds_buffer_size, 64); - }; - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer() - { - using BlockFmhaShape = remove_cvref_t; - - constexpr index_t k1_loops = BlockFmhaShape::kN0 / 
BlockFmhaShape::kK1; - constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers(); - constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers(); - - constexpr index_t last_v_lds_buffer_offset = - MakeVLdsBlockDescriptor().get_element_space_size() / num_v_lds_buffers * - ((k1_loops - 1) % num_v_lds_buffers) * sizeof(typename Problem::VDataType); - - constexpr index_t first_k_lds_buffer_size = - MakeKLdsBlockDescriptor().get_element_space_size() / num_k_lds_buffers * - sizeof(typename Problem::KDataType); - - return GetExclusiveKLdsBytes() + last_v_lds_buffer_offset < - first_k_lds_buffer_size; - }; - - template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() - { - return MakeKLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::KDataType); + return BlockGemmARegBSmemCRegOneWarpV1{}; } template - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV() { - return MakeVLdsBlockDescriptor().get_element_space_size() * - sizeof(typename Problem::VDataType); - } + constexpr index_t num_kv_lds_buffers = GetNumKVLdsBuffers(); + + return num_kv_lds_buffers * GetSingleSmemElementSpaceSize() * + max(sizeof(typename Problem::KDataType), sizeof(typename Problem::VDataType)); + }; + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout() + { + static_assert(!Problem::kHasDropout, + "BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy does not " + "account for dropout LDS scratch space. Either use a policy " + "that implements dropout shared-memory sizing or disable dropout " + "for this pipeline."); + return 0; + }; template CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { - // assume V can reuse the other shared memory by K except the first - // assume Dropout can reuse the shared memory by V - return GetExclusiveKLdsBytes() + - max(GetSmemSizeK() - GetExclusiveKLdsBytes(), - max(GetSmemSizeV(), GetSmemSizeDropout(0))); + return GetSmemSizeKV() + GetSmemSizeDropout(); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp new file mode 100644 index 0000000000..95f68623fa --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_trload.hpp @@ -0,0 +1,861 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +template +struct BlockFmhaPipelineQRKSVSWholeKPrefetchTrLoad +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using CompDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + using AttentionVariant = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; + static_assert(kQLoadOnce == Policy::QLoadOnce); + static_assert(sizeof(KDataType) == sizeof(VDataType) && + alignof(KDataType) == alignof(VDataType), + "K and V share the same LDS region; their element types must have identical " + "size and alignment."); + + static constexpr bool kUseN0Loop = true; + static constexpr bool kIgnoreFastExp2 = true; + static constexpr bool kIsNaiveHDimLoad = true; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kN0Sub = + BlockFmhaShape::kK0; // subdivision of kN0 used in N0-loop, same value as kK0 + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + + static_assert(Problem::kUseTrLoad == true, "Check failed!"); + + static constexpr bool kUseTrLoad = true; + + // since this pipeline is only used by the inference path of xformers, the Dropout function is + // not well tested with the pipeline, so here we have Dropout disabled + static_assert(kHasDropout == false, "Dropout is not supported by this pipeline at present!"); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = + Problem::kPadHeadDimV ? 
1 : Policy::template GetAlignmentV(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kQKHeaddim == 32) + { + return 2; + } + else if constexpr(kQKHeaddim == 64) + { + return 2; + } + else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim == 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_async_whole_k_prefetch_trload"; + + using DropoutType = std::conditional_t; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& /* unused */, + const AttentionVariantParams& /* unused */, + const BlockIndices& /* unused */, + void* smem_ptr, + DropoutType& dropout) const + { + // xformers path does not require the pipeline to output random values for host + // verification, since a separate kernel is used to generate random values + ignore = randval_dram_block_window_tmp; + + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0Sub == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + constexpr index_t n0_loops = kN0 / kN0Sub; + constexpr index_t k1_loops = kN0 / kK1; + + // usually kN0 is 128, kN0Sub/kK1 is 32/16 + static_assert(n0_loops >= 2, "n0_loops >= 2 required to use this pipeline"); + static_assert(k1_loops >= 2, "k1_loops >= 2 required to use this pipeline"); + + constexpr auto NumKVLdsBuffers = Policy::template GetNumKVLdsBuffers(); + + constexpr index_t NumPrefetchV = Policy::template GetNumPrefetchV(); + static_assert(n0_loops >= NumPrefetchV, "Check failed!"); + static_assert(k1_loops >= NumPrefetchV, "Check failed!"); + + constexpr bool 
kPreloadWholeNextIterationK = + Policy::template IsPreloadWholeNextIterationK(); + + // This path prefetches two k_tiles for next iteration, so it has the opportunity to + // prefetch two v_tiles during Gemm0 + if constexpr(!kPreloadWholeNextIterationK) + { + static_assert(NumPrefetchV >= 2); + }; + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + // SaccBlockTile size is [kM0, kK1] + // PcompBlockTile size is [kM0, kN0] + using SaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using CombineSaccBlockTileType = decltype(gemm_0.template MakeCBlockTile()); + using PcompBlockTileType = decltype(cast_tile(CombineSaccBlockTileType{})); + + SaccBlockTileType sacc_tile; + PcompBlockTileType pcomp_tile; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + using MLBlockTileType = decltype(block_tile_reduce( + PcompBlockTileType{}, sequence<1>{}, f_max, CompDataType{0})); + + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + OaccBlockTileType o_acc; + + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + if(seqlen_k_end <= seqlen_k_start) + { + clear_tile(o_acc); + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + return o_acc; + }; + + auto k_dram_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {seqlen_k_start, 0}, + Policy::template MakeKDramTileDistribution()); + + auto q_tile = load_tile(q_dram_window); + + using k_tile_type = decltype(load_tile(k_dram_window)); + + auto k_tiles = [&]() { + if constexpr(kPreloadWholeNextIterationK) + return statically_indexed_array{}; + else + return statically_indexed_array{}; + }(); + + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + + if constexpr(!kPreloadWholeNextIterationK) + { + k_tiles[I1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + + __builtin_amdgcn_sched_barrier(0x00000001); + + // provide partition_index for LDS tile window with so that warp_id is in vgpr + array partition_index{get_warp_id(), get_lane_id()}; + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(smem_ptr); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = make_tile_window( + k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); + + using k_lds_window_type = decltype(get_slice_tile( + k_lds_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array k_lds_windows; + + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + k_lds_windows[i_buf] = get_slice_tile(k_lds_window, + sequence{}, + sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{}); + }); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + using 
v_lds_window_type = + decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array v_lds_windows; + + static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) { + v_lds_windows[i_buf] = get_slice_tile( + v_lds_window, sequence{}, sequence<(i_buf + 1) * kK1, kN1>{}); + }); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {seqlen_k_start, 0}, + Policy::template MakeVDramTileDistribution()); + + const auto f_exp = [&](CompDataType x) { + if constexpr(std::is_same_v) + { + return __expf(x); + } + else + { + return exp(x); + } + }; + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {bias_origin.at(number<0>{}), seqlen_k_start}, + Policy::template MakeBiasDramTileDistribution()); + + // assuming no random values need be saved, this is true when the pipeline is called from + // xformers, since we have a separate kernel to generated random values + auto null_randval_window = [&]() { + if constexpr(kHasDropout) + { + // need to pass a null_randval_dram and tile window to the BlockDropout operator to + // make it works + const auto null_randval_dram = [&]() { + const auto null_dram_naive = make_naive_tensor_view( + static_cast(nullptr), + make_tuple(1, 1), + make_tuple(1, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(null_dram_naive, + make_tuple(number<1>{}, number<1>{}), + sequence{}); + }(); + + return make_tile_window( + null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0}); + } + else + return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); + }(); + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + q_tile = tile_elementwise_in(q_element_func, q_tile); + + auto seqlen_k_curr = seqlen_k_start; + + using v_tile_type = decltype(load_tile(v_dram_window)); + + statically_indexed_array v_tiles; + + do + { + // STAGE 1, Gemm_0 ( S = Q@K ) + if constexpr(kPreloadWholeNextIterationK) // used when kM0 = 64 + { + if(seqlen_k_curr == seqlen_k_start) // at first iteration + { + if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); + + if constexpr(i_n0 < n0_loops - 1) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + + if constexpr(i_n0 == n0_loops - 1) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + + // prefetch all k_tiles for next iteration + static_for<0, n0_loops, 1>{}([&](auto ii_n0) { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }); + }; + + block_sync_lds(); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + } + else // the iteration is also the last iteration + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); + + if constexpr(i_n0 < n0_loops - 1) + { + k_tiles[number{}] = load_tile(k_dram_window); + 
move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + + if constexpr(i_n0 == n0_loops - 1) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + block_sync_lds(); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + }; + } + else // at intermediate and last iteration + { + if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); + + if constexpr(i_n0 == 0) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + // prefetch k_tile for next iteration + k_tiles[i_n0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + + block_sync_lds(); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + } + else // last iteration + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); + + if constexpr(i_n0 == 0) + { + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + block_sync_lds(); + gemm_0( + sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + }; + } + } + else // only preload one unroll of K for next iteration, used when kM0=128 + { + static_for<0, n0_loops, 1>{}([&](auto i_n0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tiles[number{}]), + partition_index); + + __builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(i_n0 < n0_loops - 2) + { + k_tiles[number{}] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + + if constexpr(i_n0 >= n0_loops - 2) + { + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + __builtin_amdgcn_sched_barrier(0x00000001); + + block_sync_lds(); + + gemm_0(sacc_tile, q_tile, k_lds_windows[number{}]); + + sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile); + auto tmp_tile = cast_tile(sacc_tile); + set_slice_tile(pcomp_tile, + tmp_tile, + sequence<0, i_n0 * kN0Sub>{}, + sequence{}); + }); + } + + __builtin_amdgcn_sched_barrier(0x000000001); + + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); + + tile_elementwise_inout( + [&](auto& x, const auto y) { + x += type_convert(bias_element_func(y)); + }, + pcomp_tile, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + constexpr auto pcomp_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(pcomp_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(pcomp_spans[number<1>{}], [&](auto idx1) { 
+ const auto tile_idx = get_x_indices_from_distributed_indices( + pcomp_tile.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + pcomp_tile(i_j_idx) *= scale_s; + position_encoding.update(pcomp_tile(i_j_idx), row, col); + }); + }); + } + else + { + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, pcomp_tile); + } + + move_tile_window(bias_dram_window, {0, kN0}); + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + bool need_perpixel_check = mask.IsEdgeTile( + q_origin.at(number<0>{}), seqlen_k_curr, number{}, number{}); + if(need_perpixel_check) + { + set_tile_if(pcomp_tile, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = seqlen_k_curr + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + __builtin_amdgcn_sched_barrier(0x00000001); + + auto m_local = block_tile_reduce( + pcomp_tile, sequence<1>{}, f_max, -numeric::infinity()); + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; + + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); + + __builtin_amdgcn_sched_barrier(0); + + // check whether first V-LdsBufer overlap with last K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((n0_loops - 1) % NumKVLdsBuffers == 2 % NumKVLdsBuffers) + { + __builtin_amdgcn_s_barrier(); + }; + + store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_tiles[I0]), + partition_index); + + __builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(kPreloadWholeNextIterationK) + { + static_for<1, NumPrefetchV, 1>{}([&](auto i_k1) { + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }); + } + else + { + static_for<2, NumPrefetchV, 1>{}([&](auto i_k1) { + v_tiles[i_k1] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }); + }; + + __builtin_amdgcn_sched_barrier(0); + + constexpr auto p_spans = decltype(pcomp_tile)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = type_convert(0.0f); + }); + } + else + { + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + pcomp_tile(i_j_idx) = f_exp(pcomp_tile[i_j_idx] - m[i_idx]); + }); + } + }); + + auto rowsum_p = + block_tile_reduce(pcomp_tile, sequence<1>{}, f_sum, CompDataType{0}); + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + + // adjust o_acc[] according to the update between m and m_old + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + + if(m[i_idx] == -numeric::infinity()) + { + l(i_idx) = rowsum_p[i_idx]; + } + else + { + const auto tmp = f_exp(m_old[i_idx] - m[i_idx]); + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + } + }); + + 
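+            // Descriptive note (added): at this point the running softmax statistics are
+            // up to date for this KV tile -- m holds the new row maximum, pcomp_tile holds
+            // exp(S - m), and l / o_acc accumulated from previous tiles have been rescaled
+            // by exp(m_old - m) above so that all partial results share one normalization.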
__builtin_amdgcn_sched_barrier(0x00000001); + + if constexpr(kHasDropout) + { + auto randval_lds_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + + dropout.template Run( + randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window); + } + + seqlen_k_curr += kN0; + + __builtin_amdgcn_sched_barrier(0x00000001); + + auto p = cast_tile(tile_elementwise_in(p_compute_element_func, pcomp_tile)); + + __builtin_amdgcn_sched_barrier(0x00000001); + + // STAGE 3, Gemm_1 ( O = P@V ) + static_for<0, k1_loops, 1>{}([&](auto i_k1) { + if constexpr(i_k1 < k1_loops - NumPrefetchV) + { + v_tiles[number{}] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kK1, 0}); + }; + + if constexpr(i_k1 == k1_loops - NumPrefetchV) + { + if constexpr(!kPreloadWholeNextIterationK) + { + if(seqlen_k_curr < seqlen_k_end) + { + k_tiles[I0] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + } + }; + + if constexpr(i_k1 == k1_loops - NumPrefetchV + 1) + { + if constexpr(!kPreloadWholeNextIterationK) + { + if(seqlen_k_curr < seqlen_k_end) + { + k_tiles[I1] = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kN0Sub, 0}); + }; + } + }; + + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}]); + + if constexpr(i_k1 < k1_loops - 1) + { + store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, + v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]), + partition_index); + }; + }); + + // check whether last V-LdsBuffer overlap with first K-LdsBuffer, + // this does not occur when k1_loops == 2 and NumKVLdsBuffers == 4 + if constexpr((k1_loops - 1 + 2) % NumKVLdsBuffers == 0) + { + __builtin_amdgcn_s_barrier(); + }; + } while(seqlen_k_curr < seqlen_k_end); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + lse(i_idx) = m_[i_idx] + log(l_[i_idx]); + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + if(m[i_idx] == -numeric::infinity()) + o_acc(i_j_idx) = 0.0f; + else + o_acc(i_j_idx) *= 1.0f / l[i_idx]; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + DropoutType& dropout, + const float sink_v) const + { + ignore = 
sink_v; + + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + position_encoding, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index d2d8bb2c7e..9fc3652f51 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -692,8 +692,11 @@ struct BlockFmhaPipelineQSKSVS const AttentionVariantParams& variant_params, const BlockIndices& block_indices, void* smem_ptr, - DropoutType& dropout) const + DropoutType& dropout, + const float sink_v) const { + ignore = sink_v; + return operator()(q_dram_block_window_tmp, identity{}, k_dram_block_window_tmp, diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index 71da3767b0..f217f57bad 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -57,7 +57,7 @@ struct TileFmhaShape static constexpr index_t kQKHeaddim = BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at // once (or repeately load Q as a whole tile) - static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0"); + static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim must be divisible by kK0!"); static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(); diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp new file mode 100644 index 0000000000..f84d232196 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_k.hpp @@ -0,0 +1,268 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
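+
+// Descriptive note (added): BlockGemmARegBSmemCRegV2PrefetchK keeps A in registers and
+// reads B from shared memory; for each N iteration, the B warp tile of the next K
+// iteration is loaded from LDS before the warp GEMMs of the current K iteration are
+// issued (K-direction prefetch).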
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemCRegV2PrefetchK +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + static_assert(NWarp == 1, "Check failed!"); + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = make_static_distributed_tensor( + MakeABlockTileDistribution()); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = 
uniform_sequence_gen_t{}; + + constexpr auto I0 = number<0>{}; + + // hot loop: + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0))); + + statically_indexed_array b_warp_tensors; + + // read B warp tensor from B Block window + b_warp_windows(nIter)(I0) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(I0), + {nIter * NPerBlockPerIter, 0 * KPerBlockPerIter}); + b_warp_tensors[I0] = load_tile(b_warp_windows(nIter)(I0)); + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + if constexpr(kIter < KIterPerWarp - 1) + { + // read B warp tensor from B Block window + b_warp_windows(nIter)(number{}) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(number{}), + {nIter * NPerBlockPerIter, (kIter + 1) * KPerBlockPerIter}); + b_warp_tensors[number{}] = + load_tile(b_warp_windows(nIter)(number{})); + }; + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + if constexpr(kIter == 0) + { + // warp GEMM + c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensors[kIter]); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + } + else + { + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[kIter]); + // WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor_array[nIter]); + }; + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode(); + + return make_static_tile_distribution(a_block_dstr_encode); + } + + template + CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr 
index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + + static_assert(NWarp == 1, "Check failed!"); + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + return c_block_dstr_encode; + } + + template + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode(); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp new file mode 100644 index 0000000000..51f59e16c0 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_prefetch_n.hpp @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemCRegV2PrefetchN +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 
0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + // constrcut from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = make_static_distributed_tensor( + MakeABlockTileDistribution()); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, + make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + constexpr auto I0 = number<0>{}; + + using b_warp_tensor_type = decltype(load_tile(b_warp_windows(I0)(I0))); + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + statically_indexed_array b_warp_tensors; + + // read B warp tensor from B Block window + b_warp_windows(I0)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(I0)(kIter), + {0 * NPerBlockPerIter, kIter * KPerBlockPerIter}); + b_warp_tensors(I0) = load_tile(b_warp_windows(I0)(kIter)); + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + if constexpr(nIter < NIterPerWarp - 1) + { + // read B warp tensor from B Block window + b_warp_windows(number{})(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(number{})(kIter), + {(nIter + 1) * NPerBlockPerIter, kIter * KPerBlockPerIter}); + b_warp_tensors(number{}) = + load_tile(b_warp_windows(number{})(kIter)); + }; + + __builtin_amdgcn_sched_barrier(0x00000001); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + 
+ using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + return make_static_tile_distribution(a_block_dstr_encode); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp new file mode 100644 index 0000000000..c731539134 --- /dev/null +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_trload_creg_v2_prefetch_n.hpp @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
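+
+// Descriptive note (added): BlockGemmARegBSmemTrLoadCRegV2PrefetchN follows the same
+// structure as the PrefetchN block GEMM above, but B warp tiles are read from shared
+// memory with transpose loads (load_tile_transpose) using the transposed distribution
+// encoding supplied by InputTileDistributionTraits; the B tile of the next N iteration
+// is prefetched before the warp GEMMs of the current one.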
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmARegBSmemTrLoadCRegV2PrefetchN +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<1>{}]; + constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iNWarp = get_warp_id() % NWarp; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + // construct from A-block-tensor from A-Block-tensor-tmp + // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent + // distribution + auto a_block_tensor = make_static_distributed_tensor( + MakeABlockTileDistribution()); + + a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer(); + + constexpr auto b_warp_dstr_encode = + typename InputTileDistributionTraits::TransposedDstrEncode{}; + + // construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_block_window_tmp.get_window_origin() + multi_index<2>{0, iNWarp * WG::kN}, + make_static_tile_distribution(b_warp_dstr_encode)); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + // check C-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "wrong!"); + + using AWarpDstr = typename WG::AWarpDstr; + using CWarpDstr = typename WG::CWarpDstr; + + using AWarpTensor = typename WG::AWarpTensor; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = 
uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + constexpr auto I0 = number<0>{}; + + using b_warp_tensor_type = decltype(load_tile_transpose(b_warp_windows(I0)(I0))); + + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + statically_indexed_array b_warp_tensors; + + // read B warp tensor from B Block window + b_warp_windows(I0)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(I0)(kIter), + {kIter * KPerBlockPerIter, 0 * NPerBlockPerIter}); + b_warp_tensors(I0) = load_tile_transpose(b_warp_windows(I0)(kIter)); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + if constexpr(nIter < NIterPerWarp - 1) + { + // read B warp tensor from B Block window + b_warp_windows(number{})(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(number{})(kIter), + {kIter * KPerBlockPerIter, (nIter + 1) * NPerBlockPerIter}); + b_warp_tensors(number{}) = + load_tile_transpose(b_warp_windows(number{})(kIter)); + }; + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensors[nIter]); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + template + CK_TILE_DEVICE static constexpr auto MakeABlockTileDistribution() + { + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); + + return make_static_tile_distribution(a_block_dstr_encode); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + // constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode 
= detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, + const BBlockWindowTmp& b_block_window_tmp) const + { + auto c_block_tensor = MakeCBlockTile(); + operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); + return c_block_tensor; + } +}; + +} // namespace ck_tile From 26ff0da492b6356970646173fca8e92f69d2534d Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 24 Apr 2026 19:22:11 -0700 Subject: [PATCH 9/9] [CK] restore fmha performance reporting and disable c++17 in CI. (#6741) ## Motivation This change restores monitoring of FMHA benchmarks performance in daily builds and removes the std=c++17 flag from CI builds on gfx90a. ## Technical Details ## Test Plan ## Test Result ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- Jenkinsfile | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 05eb7f97ef..8675c716e7 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -840,8 +840,10 @@ def cmake_build(Map conf=[:]){ if (params.RUN_CK_TILE_FMHA_TESTS){ try{ - archiveArtifacts "perf_fmha_*.log" - stash includes: "perf_fmha_**.log", name: "perf_fmha_log_${arch_name}" + dir("projects/composablekernel"){ + archiveArtifacts "perf_fmha_*.log" + stash includes: "perf_fmha_**.log", name: "perf_fmha_log_${arch_name}" + } } catch(Exception err){ echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." @@ -918,7 +920,7 @@ def Build_CK(Map conf=[:]){ sh "projects/composablekernel/script/run_inductor_tests.sh" } // run performance tests, stash the logs, results will be processed on the master node - dir("projects/composablekernel/script"){ + dir("projects/composablekernel/script"){ if (params.RUN_PERFORMANCE_TESTS){ if (params.RUN_FULL_QA && (arch == "gfx90a" || arch == "gfx942")){ // run full tests on gfx90a or gfx942 @@ -1017,6 +1019,13 @@ def process_results(Map conf=[:]){ catch(Exception err){ echo "could not locate the FMHA performance logs for gfx90a: ${err.getMessage()}." } + try{ + unstash "perf_fmha_log_gfx950" + } + catch(Exception err){ + echo "could not locate the FMHA performance logs for gfx950: ${err.getMessage()}." + } + } if (params.BUILD_INSTANCES_ONLY){ // unstash deb packages