From fa39c4e7987acb39d3bb1f3c74add5acda44e164 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Wed, 21 May 2025 12:34:30 -0500 Subject: [PATCH 1/3] Add Doxygen Documentation for HostTesnor, HostTensorDescriptor, DeviceMem, FillUniformDistribution (#2160) * added documentation for HostTensorDescriptor * added documentation for DeviceMem and FillUniformDistribution * fixed merging error * fixed host_tensor_descriptor error * clang format --- include/ck/library/utility/fill.hpp | 14 +++++ include/ck_tile/host/device_memory.hpp | 31 +++++++++- include/ck_tile/host/fill.hpp | 20 ++++++- include/ck_tile/host/host_tensor.hpp | 81 +++++++++++++++++++++++++- 4 files changed, 139 insertions(+), 7 deletions(-) diff --git a/include/ck/library/utility/fill.hpp b/include/ck/library/utility/fill.hpp index 35625d142e..4f421b4282 100644 --- a/include/ck/library/utility/fill.hpp +++ b/include/ck/library/utility/fill.hpp @@ -85,6 +85,20 @@ struct FillUniformDistributionIntegerValue } }; +/** + * @brief A functor for filling a container with a monotonically increasing or decreasing sequence. + * + * FillMonotonicSeq generates a sequence of values starting from an initial value + * and incrementing by a fixed step for each subsequent element. + * + * @tparam T The numeric type of the sequence elements. + * + * Example usage: + * ``` + * std::vector v(5); + * FillMonotonicSeq{10, 2}(v); // Fills v with {10, 12, 14, 16, 18} + * ``` + */ template struct FillMonotonicSeq { diff --git a/include/ck_tile/host/device_memory.hpp b/include/ck_tile/host/device_memory.hpp index 13684c0e24..587f38987e 100644 --- a/include/ck_tile/host/device_memory.hpp +++ b/include/ck_tile/host/device_memory.hpp @@ -20,10 +20,35 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) } /** - * @brief Container for storing data in GPU device memory + * @brief Manages device memory allocation and host-device data transfers * + * DeviceMem encapsulates GPU memory management operations using HIP runtime API. + * It provides functionality for allocating device memory, transferring data between + * host and device, and performing basic memory operations. + * + * Key features: + * - Automatic memory allocation and deallocation + * - Host-to-device and device-to-host data transfers + * - Memory initialization operations + * - Integration with HostTensor for simplified data handling + * + * Usage example: + * ``` + * // Allocate device memory + * BHostTensor AHostData({256}); + * DeviceMem d_mem(BHostData.get_element_space_size_in_bytes()); + * + * // Transfer data to device + * HostTensor AHostTensor({256}); + * d_mem.ToDevice(AHostData.data()); + * + * // Retrieve data from device + * HostTensor ResultHostTensor({256}); + * d_mem.FromDevice(ResultHostTensor.data()); + * ``` */ struct DeviceMem + { DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {} DeviceMem(std::size_t mem_size) : mMemSize(mem_size) @@ -163,8 +188,8 @@ struct DeviceMem } } - void* mpDeviceBuf; - std::size_t mMemSize; + void* mpDeviceBuf; ///< pointer to device buffer + std::size_t mMemSize; ///< size of device buffer in bytes }; } // namespace ck_tile diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 3f64eb28cd..4a359e031f 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -17,13 +17,31 @@ namespace ck_tile { +/** + * @brief Functor for filling a range with randomly generated values from a uniform distribution. + * + * This struct provides functionality to fill iterators or ranges with random values + * generated from a uniform distribution. It supports both single-threaded and + * multi-threaded operation. + * + * @tparam T The target type for the generated values. + * + * @note The multi-threaded implementation is not guaranteed to provide perfectly + * distributed values across threads. + * + * @example + * + * // Direct usage without creating a separate variable: + * ck_tile::FillUniformDistribution{-1.f, 1.f}(a_host_tensor); + */ template struct FillUniformDistribution { float a_{-5.f}; float b_{5.f}; std::optional seed_{11939}; - // ATTENTION: threaded does not guarantee the distribution between thread + // ATTENTION: Whether to use multi-threading (note: not guaranteed to be perfectly distributed + // across threads). bool threaded = false; template diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index a43877c6da..deaa158d50 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -85,6 +85,19 @@ CK_TILE_HOST auto construct_f_unpack_args(F, T args) return construct_f_unpack_args_impl(args, std::make_index_sequence{}); } +/** + * @brief Descriptor for tensors in host memory. + * + * HostTensorDescriptor manages the shape (dimensions) and memory layout (strides) + * of a tensor in host memory. It provides functionality to: + * - Store tensor dimensions and strides + * - Calculate default strides for contiguous memory layout + * - Convert multi-dimensional indices to linear memory offsets + * - Query tensor metadata (dimensions, element counts, etc.) + * + * The class supports both automatic stride calculation for contiguous memory layout + * and custom strides for more complex memory patterns. + */ struct HostTensorDescriptor { HostTensorDescriptor() = default; @@ -138,12 +151,35 @@ struct HostTensorDescriptor } std::size_t get_num_of_dimension() const { return mLens.size(); } + /** + * @brief Calculates the total number of elements in the tensor. + * + * Computes the product of all dimension lengths to determine the + * total element count in the tensor. + * + * @pre The lengths array (mLens) and strides array (mStrides) must have + * the same size. + * + * @return The total number of elements in the tensor. + */ std::size_t get_element_size() const { assert(mLens.size() == mStrides.size()); return std::accumulate( mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies()); } + /** + * @brief Calculates the total element space required for the tensor in memory. + * + * This method computes the minimum size of contiguous memory needed to store + * all elements of the tensor, taking into account the tensor's dimensions and + * strides. The calculation is based on the formula: 1 + max((length_i - 1) * stride_i) + * across all dimensions. + * + * Dimensions with length 0 are skipped in this calculation. + * + * @return The size of the tensor's element space (number of elements). + */ std::size_t get_element_space_size() const { std::size_t space = 1; @@ -165,6 +201,18 @@ struct HostTensorDescriptor const std::vector& get_strides() const { return mStrides; } + /** + * @brief Calculates the linear offset from multi-dimensional indices. + * + * Converts a set of N-dimensional indices into a single linear offset by computing + * the inner product of the indices with the tensor's strides. + * + * @tparam Is Parameter pack of index types (should be convertible to std::size_t) + * @param is Variable number of indices, one for each dimension of the tensor + * @return std::size_t Linear offset corresponding to the given multi-dimensional indices + * + * @pre The number of indices must match the number of dimensions in the tensor + */ template std::size_t GetOffsetFromMultiIndex(Is... is) const { @@ -173,6 +221,15 @@ struct HostTensorDescriptor return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } + /** + * @brief Calculates the linear memory offset from a multi-dimensional index + * + * Computes the linear offset by performing an inner product between the provided + * multi-dimensional indices and the tensor's strides. + * + * @param iss Vector containing the multi-dimensional indices + * @return The calculated linear offset as a size_t + */ std::size_t GetOffsetFromMultiIndex(std::vector iss) const { return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); @@ -194,8 +251,8 @@ struct HostTensorDescriptor } private: - std::vector mLens; - std::vector mStrides; + std::vector mLens; ///< Lengths of each dimension + std::vector mStrides; ///< Strides for each dimension }; template @@ -681,6 +738,24 @@ struct HostTensor Data mData; }; +/** + * @brief Creates a host tensor descriptor with specified dimensions and layout + * + * Constructs a HostTensorDescriptor with appropriate strides based on whether the tensor + * layout is row-major or column-major. This is determined via the compile-time template + * parameter `is_row_major`. + * + * @tparam is_row_major Compile-time flag indicating if the layout is row-major (true) or + * column-major (false) + * + * @param row Number of rows in the tensor + * @param col Number of columns in the tensor + * @param stride Stride between adjacent rows (for row-major) or columns (for column-major) + * + * @return HostTensorDescriptor with shape {row, col} and strides: + * - For row-major: {stride, 1} + * - For column-major: {1, stride} + */ template auto host_tensor_descriptor(std::size_t row, std::size_t col, @@ -698,6 +773,7 @@ auto host_tensor_descriptor(std::size_t row, return HostTensorDescriptor({row, col}, {1_uz, stride}); } } + template auto get_default_stride(std::size_t row, std::size_t col, @@ -718,5 +794,4 @@ auto get_default_stride(std::size_t row, else return stride; } - } // namespace ck_tile From ebc5a6ef8717aba5ea5ee691de8e0da0dc4de04e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 21 May 2025 22:47:34 +0200 Subject: [PATCH 2/3] Grouped conv bwd wei add for larger filter and Merge Groupes optimization (#2197) * Grouped conv bwd wei add two stage instances for larger filter and Merge Groups * Fix * fix * Restore removed instances --------- Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> --- ...conv_bwd_weight_two_stage_xdl_instance.hpp | 108 ++++++++++++++++-- .../grouped_convolution_backward_weight.hpp | 16 +++ ...rouped_convolution_backward_weight_xdl.inc | 96 ++++++++++++++++ .../grouped_conv2d_bwd_weight/CMakeLists.txt | 4 + ...ngchw_gkcyx_ngkhw_bf16_pipev1_instance.cpp | 2 +- ...gkcyx_ngkhw_bf16_pipev1_part2_instance.cpp | 41 +++++++ ..._ngchw_gkcyx_ngkhw_f16_pipev1_instance.cpp | 2 +- ..._gkcyx_ngkhw_f16_pipev1_part2_instance.cpp | 41 +++++++ ...nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp | 2 +- ...gkyxc_nhwgk_bf16_pipev1_part2_instance.cpp | 41 +++++++ ..._nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp | 2 +- ..._gkyxc_nhwgk_f16_pipev1_part2_instance.cpp | 41 +++++++ .../grouped_conv3d_bwd_weight/CMakeLists.txt | 4 + ...wgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp | 2 +- ...zyxc_ndhwgk_bf16_pipev1_part2_instance.cpp | 41 +++++++ ...hwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp | 2 +- ...kzyxc_ndhwgk_f16_pipev1_part2_instance.cpp | 41 +++++++ ...dhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp | 2 +- ...czyx_ngkdhw_bf16_pipev1_part2_instance.cpp | 41 +++++++ ...cdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp | 2 +- ...kczyx_ngkdhw_f16_pipev1_part2_instance.cpp | 41 +++++++ 21 files changed, 552 insertions(+), 20 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_part2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_part2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_part2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_part2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_part2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_part2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_part2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_part2_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp index 0ed12b984b..fbcda3ca57 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -72,14 +72,31 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8> + // clang-format on + >; +template +using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_part2_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1> + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 256, 32, 8, 16, 16, 1, 16, S<4, 2, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 8>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 128, 32, 8, 16, 16, 1, 8, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 4> // clang-format on >; @@ -145,15 +162,34 @@ using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instance DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8> + // clang-format on + >; +template +using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_part2_instances = + std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 4, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 4, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 16, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 4, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1> - // clang-format on - >; + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 256, 32, 8, 16, 16, 1, 16, S<4, 2, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 8>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 128, 32, 8, 16, 16, 1, 8, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 4> + // clang-format on + >; template , S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8 ,1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 1>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, F16, F16, 8, 1> + // clang-format on + >; +template +using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_part2_instances = std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1, F16, F16, 4, 4>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1, F16, F16, 2, 2>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1> + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 256, 32, 8, 16, 16, 1, 16, S<4, 2, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 8, F16, F16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 256, 32, 8, 16, 16, 1, 16, S<4, 2, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 8, F16, F16, 2, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 256, 32, 8, 16, 16, 1, 16, S<4, 2, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 8, F16, F16, 1, 4>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 128, 32, 8, 16, 16, 1, 8, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 4, F16, F16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 128, 32, 8, 16, 16, 1, 8, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 4, F16, F16, 2, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 128, 32, 8, 16, 16, 1, 8, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 4, F16, F16, 1, 4> // clang-format on >; @@ -292,14 +351,39 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instance DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8 ,1>, DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 4, 1>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 4>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 8, 1> - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1, BF16, BF16, 4, 4>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1, BF16, BF16, 2, 2>, - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1> // clang-format on >; +template +using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_part2_instances = + std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, S<8, 8, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 8, 8, false, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1, BF16, BF16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 2, 8, false, 1, 1, S<1, 8, 1, 32>, 2, Scheduler, PipelineVersion, 1, BF16, BF16, 2, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 32, 32, 1, 1, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 1, 1, 8, false, 1, 1, S<1, 4, 1, 64>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 256, 32, 8, 16, 16, 1, 16, S<4, 2, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 256, 32, 8, 16, 16, 1, 16, S<4, 2, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 2, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 256, 32, 8, 16, 16, 1, 16, S<4, 2, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 8, BF16, BF16, 1, 4>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 128, 32, 8, 16, 16, 1, 8, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 4, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 128, 32, 8, 16, 16, 1, 8, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 2, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 128, 32, 8, 16, 16, 1, 8, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, false, 1, 1, S<1, 4, 1, 16>, 1, Scheduler, PipelineVersion, 4, BF16, BF16, 1, 4> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index e8e46a7329..a450307dc2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -383,6 +383,8 @@ struct DeviceOperationInstanceFactory>>& instances); +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_part2_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev1_part2_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_bf16_pipev2_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_part2_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_part2_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev2_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_part2_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_part2_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev2_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_part2_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_part2_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev2_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_part2_instances< + 2, + NGCHW, + GKCYX, + NGKHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instance.cpp index c28de81134..3897eac117 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pi // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances< + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances< 2, NGCHW, GKCYX, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..a832d9c3e9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/ngchw_gkcyx_ngkhw/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkcyx_ngkhw_f16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_part2_instances< + 2, + NGCHW, + GKCYX, + NGKHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp index 6e77488299..f09e9c8479 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_p // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances< + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances< 2, NHWGC, GKYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..051d8b17ac --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_part2_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp index e2ecee734f..480b84960d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances< + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances< 2, NHWGC, GKYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..bf6492a820 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_part2_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 1b0d2dd0b2..5574cf82f9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -26,6 +26,8 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_part2_instance.cpp + xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_part2_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_irregular_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instance.cpp xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_irregular_instance.cpp @@ -44,6 +46,8 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev5_instance.cpp xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp + xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_part2_instance.cpp + xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_part2_instance.cpp ) if(DL_KERNELS) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp index 4c4589d128..8dc563e079 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf1 // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances< + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances< 3, NDHWGC, GKZYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..07221a7af5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_part2_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp index 125b324985..0b96c12198 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16 // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances< + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances< 3, NDHWGC, GKZYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..2de899e66d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_part2_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp index e7cfcf1e5f..1514cb1c6c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf1 // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances< + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_instances< 3, NGCDHW, GKCZYX, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..f451708158 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_bf16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_part2_instances< + 3, + NGCDHW, + GKCZYX, + NGKDHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp index f22b0c74c0..dd7309eb62 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16 // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances< + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances< 3, NGCDHW, GKCZYX, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_part2_instance.cpp new file mode 100644 index 0000000000..9eb492d07f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/ngcdhw_gkczyx_ngkdhw/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_part2_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkczyx_ngkdhw_f16_pipev1_part2_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_part2_instances< + 3, + NGCDHW, + GKCZYX, + NGKDHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 534d4594d0906288339c937b6419f5995ef09889 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Thu, 22 May 2025 01:28:00 -0500 Subject: [PATCH 3/3] Refactor tile_window.hpp, tile_window_linear.hpp into a CK Tile Hierarchy (#2214) * window_origin variable now in base class * abstracted more functions * consolidated tile_window_static_distribution and tile_window_static_lengths * clang format * skeleton code for tile_window and tile_window_linear consolidation * more abstraction * moved variables from child to parent * clang format * removed comments * removed debug code * removed debug code * abstracting traits WIP * consolidated traits * removed comments and clang formatted --- include/ck_tile/core.hpp | 1 + include/ck_tile/core/tensor/tile_window.hpp | 540 +++++------------ .../ck_tile/core/tensor/tile_window_base.hpp | 256 +++++++++ .../core/tensor/tile_window_linear.hpp | 544 ++++++------------ 4 files changed, 571 insertions(+), 770 deletions(-) create mode 100644 include/ck_tile/core/tensor/tile_window_base.hpp diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 2ea8bf15a7..aa9411b2e1 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -54,6 +54,7 @@ #include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp" #include "ck_tile/core/tensor/tile_scatter_gather.hpp" +#include "ck_tile/core/tensor/tile_window_base.hpp" #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/tensor/tile_window_utils.hpp" diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 716b1f4ecb..d8a5c14f9b 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -13,6 +13,7 @@ #include "ck_tile/core/tensor/static_distributed_tensor.hpp" #include "ck_tile/core/tensor/tensor_adaptor.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/tensor/tile_window_base.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" @@ -34,166 +35,60 @@ template struct tile_window_with_static_distribution + : public tile_window_with_tile_dstr_base< + tile_window_with_static_distribution, + BottomTensorView_, + WindowLengths_, + StaticTileDistribution_> { - using BottomTensorView = remove_reference_t; - using WindowLengths = remove_cvref_t; - using TileDstr = remove_cvref_t; - - using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; - using BottomTensorDesc = typename BottomTensorView::TensorDesc; - - using DataType = remove_cvref_t; - - static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension(); - static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); - - static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p(); - static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y(); + using Base = tile_window_with_tile_dstr_base< + tile_window_with_static_distribution, + BottomTensorView_, + WindowLengths_, + StaticTileDistribution_>; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; static_assert(NumCoord == 1); - // TODO: check WindowLengths and StaticTileDistribution are consistent - - static_assert(ck_tile::is_known_at_compile_time::value, - "wrong! lengths should be static"); - static_assert(TileDstr::is_static(), "wrong!"); - - static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(), - "wrong! inconsistent # of diemsnions"); - - using AdaptorTopIndex = array; - using BottomTensorIndex = array; - - using WindowAdaptorCoord = - decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})); - - using BottomTensorCoord = - decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})); - - struct load_store_traits - { - private: - static constexpr auto get_vector_dim_y_scalar_per_vector() - { - const auto [ys_vector_lengths, ys_vector_strides] = - tile_window_with_static_distribution:: - get_window_adaptor_ys_safe_vector_length_strides(); - - index_t VectorDimY_ = 0; - index_t ScalarPerVector_ = 1; - - for(index_t i = 0; i < NDimY; ++i) - { - if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_) - { - ScalarPerVector_ = ys_vector_lengths[i]; - VectorDimY_ = i; - } - } - - return make_tuple(VectorDimY_, ScalarPerVector_); - } - - public: - static constexpr index_t PackedSize = - ck_tile::numeric_traits>::PackedSize; - static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); - static constexpr index_t ScalarPerVector = - get_vector_dim_y_scalar_per_vector().template at<1>(); - - // using vector_type_t = vector_type_maker_t; - // using vector_t = typename vector_type_t::type; - using vector_t = thread_buffer; - - private: - static constexpr auto scalars_per_access_ = [] { - constexpr auto scalars_per_access_arr = generate_array( - [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number{}); - - /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE() - constexpr auto NDimY_ = NDimY; - - return TO_SEQUENCE(scalars_per_access_arr, NDimY_); - }(); - - static constexpr auto get_space_filling_curve() - { - constexpr auto tile_dstr = TileDstr{}; - - constexpr auto thread_tensor_lengths_ys = - to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths()); - - // FIXME: need logic to judge dim access order - using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type; - - return space_filling_curve{}; - } - - public: - using SFC_Ys = decltype(get_space_filling_curve()); - - static constexpr index_t NumAccess = SFC_Ys::get_num_of_access(); - - static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0"); - static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord"); - }; - - static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord; + static_assert(Base::Traits::NumAccess % NumCoord == 0, + "wrong! # of access is not divisible by NumCoord"); + static constexpr index_t NumAccessPerCoord = Base::Traits::NumAccess / NumCoord; CK_TILE_DEVICE constexpr tile_window_with_static_distribution() = default; CK_TILE_DEVICE constexpr tile_window_with_static_distribution( - const BottomTensorView& bottom_tensor_view, - const WindowLengths& window_lengths, - const BottomTensorIndex& window_origin, - const TileDstr& tile_distribution) - : bottom_tensor_view_{bottom_tensor_view}, - window_lengths_{window_lengths}, - window_origin_{window_origin}, - tile_dstr_{tile_distribution}, - pre_computed_coords_{} + const typename Base::BottomTensorView& bottom_tensor_view, + const typename Base::WindowLengths& window_lengths, + const typename Base::BottomTensorIndex& window_origin, + const typename Base::TileDstr& tile_distribution) + : pre_computed_coords_{} { -#if 0 // debug - // TODO: this use more register for FA, but less register for GEMM - // need investigation - // only support warp-tile and block-tile - static_assert(NDimP == 1 or NDimP == 2, "wrong!"); - WindowAdaptorCoord window_adaptor_thread_coord_tmp; - - if constexpr(NDimP == 1) - { - window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( - tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); - } - else if constexpr(NDimP == 2) - { - window_adaptor_thread_coord_tmp = - make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(), - AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); - } -#else - // TODO: this use less register for FA, but more register for GEMM - // need investigation + this->window_origin_ = window_origin; + this->window_lengths_ = window_lengths; + this->bottom_tensor_view_ = bottom_tensor_view; + this->tile_dstr_ = tile_distribution; const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), container_concat(detail::get_partition_index(tile_distribution), - array{0})); -#endif + array{0})); - BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( - bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + bottom_tensor_view.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up // future load/store() calls (might allocate more registers) - using Traits = load_store_traits; + using Traits = typename Base::Traits; using SFC_Ys = typename Traits::SFC_Ys; static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -204,9 +99,10 @@ struct tile_window_with_static_distribution SFC_Ys::get_step_between(number<0>{}, number{}); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); pre_computed_coords_(iCoord) = @@ -214,95 +110,12 @@ struct tile_window_with_static_distribution }); } - CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } - - CK_TILE_DEVICE static constexpr bool has_static_tile_distribution() - { - return TileDstr::is_static(); - } - - CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } - - CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; } - - CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } - - CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } - - CK_TILE_DEVICE constexpr void - set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) - { - bottom_tensor_view_.buf_.p_data_ = data; - } - - // move thread's window adaptor coordinate and bottom tensor coordinate - // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] - template - CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( - WindowAdaptorCoord& window_adaptor_thread_coord, - BottomTensorCoord& bottom_tensor_thread_coord, - const ATopIndex& idx_diff_adaptor_top) const - { - array idx_diff_adaptor_bottom; - - move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), - window_adaptor_thread_coord, - idx_diff_adaptor_top, - idx_diff_adaptor_bottom); - - move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), - bottom_tensor_thread_coord, - idx_diff_adaptor_bottom); - } - - // return vector dimension among [y0, y1, ...] - CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides() - { - // bottom tensor top dimension vector lengths and strides - const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] = - BottomTensorDesc::get_top_dimension_safe_vector_length_strides(); - - // window vector lengths/strides - const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths; - const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides; - - // window adaptor [p0, p1, ..., y0, y1, ...] - array window_adaptor_vector_lengths{ - -1}; - array window_adaptor_vector_strides{ - -1}; - - constexpr auto window_adaptor_bottom_dims = - WindowAdaptor::get_bottom_dimension_hidden_ids(); - - set_container_subset(window_adaptor_vector_lengths, - window_adaptor_bottom_dims, - window_adaptor_bottom_dim_vector_lengths); - set_container_subset(window_adaptor_vector_strides, - window_adaptor_bottom_dims, - window_adaptor_bottom_dim_vector_strides); - - const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] = - WindowAdaptor{}.get_top_dimension_safe_vector_length_strides( - window_adaptor_vector_lengths, window_adaptor_vector_strides); - - // [y0, y1, ...] - constexpr auto y_dims = typename arithmetic_sequence_gen::type{}; - - return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims), - get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); - } - - CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; } - template CK_TILE_DEVICE auto load(number = {}, bool_constant = {}) const { - constexpr auto tile_dstr = TileDstr{}; - auto dst_tensor = make_static_distributed_tensor(tile_dstr); + constexpr auto tile_dstr = typename Base::TileDstr{}; + auto dst_tensor = make_static_distributed_tensor(tile_dstr); load(dst_tensor, number{}, bool_constant{}); return dst_tensor; } @@ -314,11 +127,11 @@ struct tile_window_with_static_distribution number = {}, bool_constant = {}) const { - using Traits = load_store_traits; + using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -334,9 +147,8 @@ struct tile_window_with_static_distribution // read from bottom tensor const vector_t vec_value = - get_bottom_tensor_view().template get_vectorized_elements( + this->get_bottom_tensor_view().template get_vectorized_elements( bottom_tensor_thread_coord, 0, bool_constant{}); -#if 1 // write into distributed tensor static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( @@ -344,33 +156,26 @@ struct tile_window_with_static_distribution return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Traits::PackedSize; dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j / Traits::PackedSize]; + vec_value + .template get_as()[j / Traits::PackedSize]; }); -#else - constexpr index_t d = - tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); - static_assert(d % Traits::ScalarPerVector == 0); - - dst_tensor.get_thread_buffer().template get_as()( - number{}) = bit_cast(vec_value); -#endif // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) { constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); } }); @@ -386,22 +191,16 @@ struct tile_window_with_static_distribution bool_constant = {}, bool_constant = {}) const { - using Traits = load_store_traits; - - // using vector_type_t = typename Traits::vector_type_t; + using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; static constexpr index_t YElementSize = - TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); + typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0); using vectorized_tbuf = array; - // StaticBuffer; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; auto& dst_vec_tbuf = reinterpret_cast(dst_tensor.get_thread_buffer()); @@ -427,7 +226,7 @@ struct tile_window_with_static_distribution Traits::PackedSize; static_assert(d % Traits::ScalarPerVector == 0); - get_bottom_tensor_view().template get_vectorized_elements_raw( + this->get_bottom_tensor_view().template get_vectorized_elements_raw( dst_vec_tbuf.template at(), bottom_tensor_thread_coord, 0 /**/, @@ -444,10 +243,10 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); } }); @@ -492,9 +291,8 @@ struct tile_window_with_static_distribution const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); m0_set_with_memory(m0_init_value); // This should be wave independent - using Traits = load_store_traits; + using Traits = typename Base::Traits; - // using vector_type_t = typename Traits::vector_type_t; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; @@ -516,7 +314,7 @@ struct tile_window_with_static_distribution }(); // read from bottom tensor - get_bottom_tensor_view().template async_get_vectorized_elements_raw( + this->get_bottom_tensor_view().template async_get_vectorized_elements_raw( smem, bottom_tensor_thread_coord, 0, pre_nop_); // move thread coordinate @@ -525,10 +323,10 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); m0_inc_with_memory(size_per_issue); @@ -569,7 +367,7 @@ struct tile_window_with_static_distribution const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - using Traits = load_store_traits; + using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; @@ -588,7 +386,7 @@ struct tile_window_with_static_distribution constexpr auto iAccess = number{}; // read from bottom tensor - get_bottom_tensor_view().template async_get_vectorized_elements( + this->get_bottom_tensor_view().template async_get_vectorized_elements( smem, bottom_tensor_thread_coord, 0, bool_constant{}); // move thread coordinate @@ -597,10 +395,10 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); smem += size_per_issue; // Note we manually increase the per-issue offset @@ -610,17 +408,18 @@ struct tile_window_with_static_distribution } template - CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, + CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, number = {}, bool_constant = {}) const { - using Traits = load_store_traits; + using Traits = typename Base::Traits; // using vector_type_t = typename Traits::vector_type_t; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -643,20 +442,20 @@ struct tile_window_with_static_distribution return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Traits::PackedSize; - vec_value.template get_as()(j / Traits::PackedSize) = + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // const vector_t vec_value = vec.template get_as().template at<0>(); // write into bottom tensor - get_bottom_tensor_view().template set_vectorized_elements( + this->get_bottom_tensor_view().template set_vectorized_elements( bottom_tensor_thread_coord, 0, vec_value, @@ -668,10 +467,10 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); } }); @@ -679,15 +478,17 @@ struct tile_window_with_static_distribution } template - CK_TILE_DEVICE void store_raw(const static_distributed_tensor& dstr_tensor, - number = {}) const + CK_TILE_DEVICE void + store_raw(const static_distributed_tensor& + dstr_tensor, + number = {}) const { - using Traits = load_store_traits; + using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; static constexpr bool oob_conditional_check = true; // loop over thread tensor space [y0, y1, ...] @@ -710,16 +511,16 @@ struct tile_window_with_static_distribution return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Traits::PackedSize; - vec_value.template get_as()(j / Traits::PackedSize) = + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // write into bottom tensor - get_bottom_tensor_view() + this->get_bottom_tensor_view() .template set_vectorized_elements_raw( bottom_tensor_thread_coord, 0, vec_value); @@ -729,10 +530,10 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); } }); @@ -740,16 +541,18 @@ struct tile_window_with_static_distribution } template - CK_TILE_DEVICE void update(const static_distributed_tensor& dstr_tensor, - number = {}, - bool_constant = {}) const + CK_TILE_DEVICE void + update(const static_distributed_tensor& + dstr_tensor, + number = {}, + bool_constant = {}) const { - using Traits = load_store_traits; + using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -772,18 +575,18 @@ struct tile_window_with_static_distribution return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Traits::PackedSize; - vec_value.template get_as()(j / Traits::PackedSize) = + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // write into bottom tensor - get_bottom_tensor_view().template update_vectorized_elements( + this->get_bottom_tensor_view().template update_vectorized_elements( bottom_tensor_thread_coord, 0, vec_value, @@ -795,10 +598,10 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); } }); @@ -806,17 +609,19 @@ struct tile_window_with_static_distribution } template - CK_TILE_DEVICE void update_raw(const static_distributed_tensor& dstr_tensor, - number = {}, - bool_constant = {}, - bool_constant = {}) const + CK_TILE_DEVICE void + update_raw(const static_distributed_tensor& + dstr_tensor, + number = {}, + bool_constant = {}, + bool_constant = {}) const { - using Traits = load_store_traits; + using Traits = typename Base::Traits; using vector_t = typename Traits::vector_t; using SFC_Ys = typename Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -839,18 +644,18 @@ struct tile_window_with_static_distribution return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Traits::PackedSize; - vec_value.template get_as()(j / Traits::PackedSize) = + vec_value.template get_as()(j / Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // write into bottom tensor - get_bottom_tensor_view().template update_vectorized_elements_raw( + this->get_bottom_tensor_view().template update_vectorized_elements_raw( bottom_tensor_thread_coord, 0, vec_value, @@ -863,70 +668,44 @@ struct tile_window_with_static_distribution constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); } }); }); } - // move thread's botom tensor coordiante - // [x0', x1', ... ] ==> [offset] - // also move window-origin - CK_TILE_DEVICE void move(const BottomTensorIndex& step) + // Custom move behavior + CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex& step) { - window_origin_ += step; - static_for<0, NumCoord, 1>{}([&](auto iCoord) { - move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(), pre_computed_coords_(iCoord)(I1), step); }); } - CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&) { - window_origin_ = new_window_origin; - -#if 0 // debug - // TODO: this use more register for FA, but less register for GEMM - // need investigation - // only support warp-tile and block-tile - static_assert(NDimP == 1 or NDimP == 2, "wrong!"); - - WindowAdaptorCoord window_adaptor_thread_coord_tmp; - - if constexpr(NDimP == 1) - { - window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( - tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); - } - else if constexpr(NDimP == 2) - { - window_adaptor_thread_coord_tmp = - make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), - AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); - } -#else // TODO: this use less register for FA, but more register for GEMM // need investigation const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( - tile_dstr_.get_ps_ys_to_xs_adaptor(), - container_concat(detail::get_partition_index(tile_dstr_), array{0})); -#endif + this->tile_dstr_.get_ps_ys_to_xs_adaptor(), + container_concat(detail::get_partition_index(this->tile_dstr_), + array{0})); - BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = - window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); + typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( - bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up // future load/store() calls (might allocate more registers) - using Traits = load_store_traits; + using Traits = typename Base::Traits; using SFC_Ys = typename Traits::SFC_Ys; static_for<0, NumCoord, 1>{}([&](auto iCoord) { @@ -937,9 +716,10 @@ struct tile_window_with_static_distribution SFC_Ys::get_step_between(number<0>{}, number{}); constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); pre_computed_coords_(iCoord) = @@ -947,27 +727,11 @@ struct tile_window_with_static_distribution }); } - CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); } - - // this is the bottom tensor view - // [x0', x1', ...] ==> [offset] - BottomTensorView bottom_tensor_view_; - - // - WindowLengths window_lengths_; - - // origin ([x0', x1', ...]) of window on bottom tensor - BottomTensorIndex window_origin_; - - // Tile tensor distribution, which contains: - // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] - // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d] - TileDstr tile_dstr_; - // this contains: // per-thread coordinate for window adaptor // per-thread coordinate for bottom tensor - array, NumCoord> pre_computed_coords_; + array, NumCoord> + pre_computed_coords_; }; // TODO: use strategy @@ -1037,62 +801,26 @@ CK_TILE_DEVICE void move_tile_window( */ template struct tile_window_with_static_lengths + : public tile_window_base, + BottomTensorView_, + WindowLengths_> { - using BottomTensorView = remove_reference_t; - using WindowLengths = remove_cvref_t; - using BottomTensorDesc = typename BottomTensorView::TensorDesc; - using DataType = typename BottomTensorView::DataType; - - static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); - - static_assert(ck_tile::is_known_at_compile_time::value, - "wrong! lengths should be static"); - - using BottomTensorIndex = array; + using Base = + tile_window_base, + BottomTensorView_, + WindowLengths_>; CK_TILE_DEVICE constexpr tile_window_with_static_lengths() = default; CK_TILE_DEVICE constexpr tile_window_with_static_lengths( - const BottomTensorView& bottom_tensor_view, - const WindowLengths& window_lengths, - const BottomTensorIndex& window_origin) - : bottom_tensor_view_{bottom_tensor_view}, - window_lengths_{window_lengths}, - window_origin_{window_origin} + const typename Base::BottomTensorView& bottom_tensor_view, + const typename Base::WindowLengths& window_lengths, + const typename Base::BottomTensorIndex& window_origin) { + this->window_origin_ = window_origin; + this->window_lengths_ = window_lengths; + this->bottom_tensor_view_ = bottom_tensor_view; } - - CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } - - CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } - - CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } - - CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } - - CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) - { - window_origin_ = new_window_origin; - } - - CK_TILE_DEVICE constexpr void - set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) - { - bottom_tensor_view_.buf_.p_data_ = data; - } - - // move window-origin - CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; } - - // this is the bottom tensor view - // [x0', x1', ...] ==> [offset] - BottomTensorView bottom_tensor_view_; - - // - WindowLengths window_lengths_; - - // origin ([x0', x1', ...]) of window on bottom tensor - BottomTensorIndex window_origin_; }; template diff --git a/include/ck_tile/core/tensor/tile_window_base.hpp b/include/ck_tile/core/tensor/tile_window_base.hpp new file mode 100644 index 0000000000..89a928a53c --- /dev/null +++ b/include/ck_tile/core/tensor/tile_window_base.hpp @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/arch/utility.hpp" +#include "ck_tile/core/algorithm/space_filling_curve.hpp" +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/tensor/static_distributed_tensor.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +/** + * @brief This class provides description of tile windowed view on the device memory. + * + * @note This class does not provide any functions to read or modify device memory. + * + * @tparam BottomTensorView_ Class describing & holding device tensor memory. + * @tparam WindowLengths_ Spatial sizes of windowed view on tensor. + */ +template +struct tile_window_base +{ + + using BottomTensorView = remove_reference_t; + using WindowLengths = remove_cvref_t; + using BottomTensorDesc = typename BottomTensorView::TensorDesc; + using DataType = remove_cvref_t; + + static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); + + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + + using BottomTensorIndex = array; + + CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } + CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } + CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } + CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } + + CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + { + window_origin_ = new_window_origin; + + // Delegate to child if it implements extra logic + static_cast(this)->set_window_origin_extended(new_window_origin); + } + // Default no-op; can be overridden in child + CK_TILE_DEVICE void set_window_origin_extended(const BottomTensorIndex&) {} + + CK_TILE_DEVICE constexpr void + set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) + { + bottom_tensor_view_.buf_.p_data_ = data; + } + + // move window-origin + CK_TILE_DEVICE void move(const BottomTensorIndex& step) + { + window_origin_ += step; + + // Delegate to child if it implements extra movement logic + static_cast(this)->move_extended(step); + } + + // Default no-op; can be overridden in child + CK_TILE_DEVICE void move_extended(const BottomTensorIndex&) {} + + // origin ([x0', x1', ...]) of window on bottom tensor + BottomTensorIndex window_origin_; + + WindowLengths window_lengths_; + + // this is the bottom tensor view + // [x0', x1', ...] ==> [offset] + BottomTensorView bottom_tensor_view_; +}; + +template +struct tile_window_with_tile_dstr_base + : public tile_window_base +{ + using TileDstr = remove_cvref_t; + using TileWindowBase = tile_window_base; + + using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; + + static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension(); + + static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p(); + static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y(); + + using AdaptorTopIndex = array; + // using BottomTensorIndex = array; + + using WindowAdaptorCoord = + decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})); + + using BottomTensorCoord = decltype(make_tensor_coordinate( + typename TileWindowBase::BottomTensorDesc{}, typename TileWindowBase::BottomTensorIndex{})); + + static_assert(TileDstr::is_static(), "wrong!"); + static_assert(TileWindowBase::NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(), + "wrong! inconsistent # of diemsnions"); + + CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; } + CK_TILE_HOST_DEVICE void init_raw() { this->bottom_tensor_view_.init_raw(); } + + CK_TILE_DEVICE static constexpr bool has_static_tile_distribution() + { + return TileDstr::is_static(); + } + + // move thread's window adaptor coordinate and bottom tensor coordinate + // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] + template + CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( + WindowAdaptorCoord& window_adaptor_thread_coord, + BottomTensorCoord& bottom_tensor_thread_coord, + const ATopIndex& idx_diff_adaptor_top) const + { + array idx_diff_adaptor_bottom; + + move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + window_adaptor_thread_coord, + idx_diff_adaptor_top, + idx_diff_adaptor_bottom); + + move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(), + bottom_tensor_thread_coord, + idx_diff_adaptor_bottom); + } + + struct Traits + { + public: + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr auto get_vector_dim_y_scalar_per_vector() + { + const auto [ys_vector_lengths, ys_vector_strides] = + tile_window_with_tile_dstr_base::get_window_adaptor_ys_safe_vector_length_strides(); + + index_t VectorDimY_ = 0; + index_t ScalarPerVector_ = 1; + + for(index_t i = 0; i < NDimY; ++i) + { + if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_) + { + ScalarPerVector_ = ys_vector_lengths[i]; + VectorDimY_ = i; + } + } + + return make_tuple(VectorDimY_, ScalarPerVector_); + } + + static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); + static constexpr index_t ScalarPerVector = + get_vector_dim_y_scalar_per_vector().template at<1>(); + using vector_t = + thread_buffer; + + static constexpr auto scalars_per_access_ = [] { + constexpr auto scalars_per_access_arr = generate_array( + [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number{}); + + /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE() + constexpr auto NDimY_ = NDimY; + + return TO_SEQUENCE(scalars_per_access_arr, NDimY_); + }(); + + static constexpr auto get_space_filling_curve() + { + constexpr auto thread_tensor_lengths_ys = + to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths()); + + // FIXME: need logic to judge dim access order + using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type; + + return space_filling_curve{}; + } + + using SFC_Ys = decltype(get_space_filling_curve()); + + static constexpr index_t NumAccess = SFC_Ys::get_num_of_access(); + + static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0"); + }; + + // return vector dimension among [y0, y1, ...] + CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides() + { + // bottom tensor top dimension vector lengths and strides + const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] = + TileWindowBase::BottomTensorDesc::get_top_dimension_safe_vector_length_strides(); + + // window vector lengths/strides + const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths; + const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides; + + // window adaptor [p0, p1, ..., y0, y1, ...] + array window_adaptor_vector_lengths{ + -1}; + array window_adaptor_vector_strides{ + -1}; + + constexpr auto window_adaptor_bottom_dims = + WindowAdaptor::get_bottom_dimension_hidden_ids(); + + set_container_subset(window_adaptor_vector_lengths, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_lengths); + set_container_subset(window_adaptor_vector_strides, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_strides); + + const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] = + WindowAdaptor{}.get_top_dimension_safe_vector_length_strides( + window_adaptor_vector_lengths, window_adaptor_vector_strides); + + // [y0, y1, ...] + constexpr auto y_dims = typename arithmetic_sequence_gen::type{}; + + return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims), + get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); + } + + CK_TILE_DEVICE constexpr auto get_num_of_access() const { return Traits::NumAccess; } + // Tile tensor distribution, which contains: + // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] + // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d] + TileDstr tile_dstr_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 5ecaf5ca17..f11610d658 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -13,6 +13,7 @@ #include "ck_tile/core/tensor/static_distributed_tensor.hpp" #include "ck_tile/core/tensor/tensor_adaptor.hpp" #include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/tensor/tile_window_base.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/type_traits.hpp" @@ -37,171 +38,48 @@ namespace ck_tile { // TODO: if using this struct, better use load_raw()/store_raw(), can control // the the immediate offset on the fly // space-filing-curve is non-snaked here! -// +// This struct inherits from tile_window_with_tile_dstr_base, which is an intermediary base class +// with the ultimate parent class being tile_window_base. template struct tile_window_linear + : public tile_window_with_tile_dstr_base, + BottomTensorView_, + WindowLengths_, + StaticTileDistribution_> { + using Base = tile_window_with_tile_dstr_base, + BottomTensorView_, + WindowLengths_, + StaticTileDistribution_>; - using BottomTensorView = remove_reference_t; - using WindowLengths = remove_cvref_t; - using TileDstr = remove_cvref_t; - - using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; - using BottomTensorDesc = typename BottomTensorView::TensorDesc; - - using DataType = remove_cvref_t; using LinearBottomDims = remove_cvref_t; - static_assert(LinearBottomDims::size() == BottomTensorView::get_num_of_dimension()); - - static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension(); - static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); - - static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p(); - static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y(); + static_assert(LinearBottomDims::size() == Base::BottomTensorView::get_num_of_dimension()); static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; - // TODO: check WindowLengths and StaticTileDistribution are consistent - - static_assert(ck_tile::is_known_at_compile_time::value, - "wrong! lengths should be static"); - static_assert(TileDstr::is_static(), "wrong!"); - - static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(), - "wrong! inconsistent # of diemsnions"); - - using AdaptorTopIndex = array; - using BottomTensorIndex = array; - - using WindowAdaptorCoord = - decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})); - - using BottomTensorCoord = - decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})); - struct traits { - private: - // return vector dimension among [y0, y1, ...] - CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides() - { - // bottom tensor top dimension vector lengths and strides - const auto [bottom_tensor_top_dim_vector_lengths, - bottom_tensor_top_dim_vector_strides] = - BottomTensorDesc::get_top_dimension_safe_vector_length_strides(); - - // window vector lengths/strides - const auto window_adaptor_bottom_dim_vector_lengths = - bottom_tensor_top_dim_vector_lengths; - const auto window_adaptor_bottom_dim_vector_strides = - bottom_tensor_top_dim_vector_strides; - - // window adaptor [p0, p1, ..., y0, y1, ...] - array - window_adaptor_vector_lengths{-1}; - array - window_adaptor_vector_strides{-1}; - - constexpr auto window_adaptor_bottom_dims = - WindowAdaptor::get_bottom_dimension_hidden_ids(); - - set_container_subset(window_adaptor_vector_lengths, - window_adaptor_bottom_dims, - window_adaptor_bottom_dim_vector_lengths); - set_container_subset(window_adaptor_vector_strides, - window_adaptor_bottom_dims, - window_adaptor_bottom_dim_vector_strides); - - const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] = - WindowAdaptor{}.get_top_dimension_safe_vector_length_strides( - window_adaptor_vector_lengths, window_adaptor_vector_strides); - - // [y0, y1, ...] - constexpr auto y_dims = - typename arithmetic_sequence_gen::type{}; - - return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims), - get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); - } - - static constexpr auto get_vector_dim_y_scalar_per_vector() - { - const auto [ys_vector_lengths, ys_vector_strides] = - get_window_adaptor_ys_safe_vector_length_strides(); - - index_t VectorDimY_ = 0; - index_t ScalarPerVector_ = 1; - - for(index_t i = 0; i < NDimY; ++i) - { - if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_) - { - ScalarPerVector_ = ys_vector_lengths[i]; - VectorDimY_ = i; - } - } - - return make_tuple(VectorDimY_, ScalarPerVector_); - } - - public: - static constexpr index_t PackedSize = - ck_tile::numeric_traits>::PackedSize; - static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); - static constexpr index_t ScalarPerVector = - get_vector_dim_y_scalar_per_vector().template at<1>(); - - using vector_t = thread_buffer; - - private: - static constexpr auto scalars_per_access_ = [] { - constexpr auto scalars_per_access_arr = generate_array( - [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number{}); - - /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE() - constexpr auto NDimY_ = NDimY; - - return TO_SEQUENCE(scalars_per_access_arr, NDimY_); - }(); - - static constexpr auto get_space_filling_curve() - { - constexpr auto thread_tensor_lengths_ys = - to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths()); - - // FIXME: need logic to judge dim access order - using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type; - - return space_filling_curve{}; - } - - public: - using SFC_Ys = decltype(get_space_filling_curve()); - - static constexpr index_t NumAccess = SFC_Ys::get_num_of_access(); - - static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0"); - private: static constexpr auto get_num_non_linear_access() { - constexpr auto sfc_access_lens = SFC_Ys::access_lengths; - using ys_to_rhs_major = - typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths; + using ys_to_rhs_major = typename decltype( + typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto non_linear = [&]() { index_t cnt = 1; - static_for<0, NDimY, 1>{}([&](auto i_dim_y) { + static_for<0, Base::NDimY, 1>{}([&](auto i_dim_y) { constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y]; constexpr auto target_h_dim = number{}; // no r dim here! if constexpr(LinearBottomDims{}[target_h_dim] == 0) @@ -230,20 +108,20 @@ struct tile_window_linear // -> prefixsum : seqneuce<0, 2, 4, 6, 8> static constexpr auto get_non_linear_access_map() { - constexpr auto sfc_access_lens = SFC_Ys::access_lengths; - using ys_to_rhs_major = - typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths; + using ys_to_rhs_major = typename decltype( + typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto non_linear_map = [&]() { - array m_{0}; + array m_{0}; index_t cumulative_len_ = 1; index_t cumulative_non_linear_len_ = 1; - static_for<0, NDimY, 1>{}([&](auto i_y) { - constexpr auto i_dim_y = number{}; // from right to left + static_for<0, Base::NDimY, 1>{}([&](auto i_y) { + constexpr auto i_dim_y = number{}; // from right to left constexpr auto rhs_major = ys_to_rhs_major{}[i_dim_y]; constexpr auto target_h_dim = number{}; // no r dim here! constexpr auto is_linear_dim = LinearBottomDims{}[target_h_dim]; - array current_m_{0}; + array current_m_{0}; constexpr auto current_len_ = sfc_access_lens[i_dim_y]; // copy cumulative length as current pattern @@ -266,13 +144,12 @@ struct tile_window_linear return m_; }(); - return TO_SEQUENCE(non_linear_map, NumAccess); + return TO_SEQUENCE(non_linear_map, Base::Traits::NumAccess); } static constexpr auto get_non_linear_access_histogram() { constexpr auto m_ = get_non_linear_access_map(); - // m_.foo(); constexpr auto r_ = typename arithmetic_sequence_gen<0, get_num_non_linear_access() + 1, 1>::type{}; @@ -296,7 +173,7 @@ struct tile_window_linear using AccessPrefixSum_NonLinear = decltype(get_non_linear_access_histogram_prefix_sum()); }; - static constexpr index_t NumAccess = traits::NumAccess; + static constexpr index_t NumAccess = Base::Traits::NumAccess; static constexpr index_t NumAccess_NonLinear = traits::NumAccess_NonLinear; using AccessMap_NonLinear = typename traits::AccessMap_NonLinear; using AccessHistogram_NonLinear = typename traits::AccessHistogram_NonLinear; @@ -304,30 +181,31 @@ struct tile_window_linear CK_TILE_DEVICE constexpr tile_window_linear() = default; - CK_TILE_DEVICE constexpr tile_window_linear(const BottomTensorView& bottom_tensor_view, - const WindowLengths& window_lengths, - const BottomTensorIndex& window_origin, - const TileDstr& tile_distribution) - : bottom_tensor_view_{bottom_tensor_view}, - window_lengths_{window_lengths}, - window_origin_{window_origin}, - tile_dstr_{tile_distribution}, - cached_coords_{}, - cached_flags_{} + CK_TILE_DEVICE constexpr tile_window_linear( + const typename Base::BottomTensorView& bottom_tensor_view, + const typename Base::WindowLengths& window_lengths, + const typename Base::BottomTensorIndex& window_origin, + const typename Base::TileDstr& tile_distribution) + : cached_coords_{}, cached_flags_{} { + this->bottom_tensor_view_ = bottom_tensor_view; + this->window_lengths_ = window_lengths; + this->window_origin_ = window_origin; + this->tile_dstr_ = tile_distribution; auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( tile_distribution.get_ps_ys_to_xs_adaptor(), - container_concat(make_tuple(get_warp_id(), get_lane_id()), - generate_tuple([&](auto) { return number<0>{}; }, number{}))); + container_concat( + make_tuple(get_warp_id(), get_lane_id()), + generate_tuple([&](auto) { return number<0>{}; }, number{}))); - BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( - bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); // future load/store() calls (might allocate more registers) - using SFC_Ys = typename traits::SFC_Ys; + using SFC_Ys = typename Base::Traits::SFC_Ys; static_for<0, NumAccess, 1>{}([&](auto i_access) { constexpr auto non_linear_id = number{}; @@ -343,16 +221,16 @@ struct tile_window_linear // cached flag is independent from non-linear-coord // but need be updated in move_tile, with proper dims cached_flags_(i_access) = coordinate_has_valid_offset_assuming_top_index_is_valid( - bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp); + this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp); if constexpr(i_access != (NumAccess - 1)) { constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord_tmp, bottom_tensor_thread_coord_tmp, idx_diff_ps_ys); @@ -360,54 +238,13 @@ struct tile_window_linear }); } - CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } - - CK_TILE_DEVICE static constexpr bool has_static_tile_distribution() - { - return TileDstr::is_static(); - } - - CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } - - CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; } - - CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } - - CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } - - CK_TILE_DEVICE constexpr void - set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) - { - bottom_tensor_view_.buf_.p_data_ = data; - } - - // move thread's window adaptor coordinate and bottom tensor coordinate - // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] - template - CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( - WindowAdaptorCoord& window_adaptor_thread_coord, - BottomTensorCoord& bottom_tensor_thread_coord, - const ATopIndex& idx_diff_adaptor_top) const - { - array idx_diff_adaptor_bottom; - - move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), - window_adaptor_thread_coord, - idx_diff_adaptor_top, - idx_diff_adaptor_bottom); - - move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), - bottom_tensor_thread_coord, - idx_diff_adaptor_bottom); - } - template CK_TILE_DEVICE static constexpr auto get_bottom_linear_coordinate(number) { - using SFC_Ys = typename traits::SFC_Ys; + using SFC_Ys = typename Base::Traits::SFC_Ys; constexpr auto idx_ys = SFC_Ys::get_index(number{}); - using ys_to_rhs_major = - typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + using ys_to_rhs_major = typename decltype( + typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto modified_idx_ys = generate_tuple( [&](auto i_dim_y) { @@ -422,9 +259,9 @@ struct tile_window_linear return number{}; } }, - number{}); + number{}); - constexpr auto adaptor_ = TileDstr{}.get_ps_ys_to_xs_adaptor(); + constexpr auto adaptor_ = typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor(); constexpr auto idx_ = container_concat(make_tuple(number<0>{}, number<0>{}), modified_idx_ys); @@ -441,8 +278,8 @@ struct tile_window_linear { // this case usually is a LDS window, everything is known at compile tile. // we directly use BottomTensorView transform to compute the offset, in case padding - auto bottom_tensor_coord = - make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord); + auto bottom_tensor_coord = make_tensor_coordinate( + typename Base::BottomTensorView{}.get_tensor_descriptor(), linear_coord); return bottom_tensor_coord.get_offset(); } else @@ -453,7 +290,7 @@ struct tile_window_linear // since that would introduce runtime length (so can't use linear offset) constexpr index_t linear_offset = [&]() { constexpr auto x_idx_ = linear_coord; - constexpr auto x_len_ = TileDstr{}.get_lengths(); + constexpr auto x_len_ = typename Base::TileDstr{}.get_lengths(); static_assert(x_idx_.size() == x_len_.size()); constexpr index_t x_dims_ = x_idx_.size(); index_t cu_stride_ = 1; @@ -469,17 +306,16 @@ struct tile_window_linear } } - CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; } - template CK_TILE_DEVICE auto load(number = {}, bool_constant = {}) const { - using vector_t = typename traits::vector_t; - using SFC_Ys = typename traits::SFC_Ys; + using vector_t = typename Base::Traits::vector_t; + using SFC_Ys = typename Base::Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; - auto dst_tensor = make_static_distributed_tensor(tile_dstr); + auto dst_tensor = + make_static_distributed_tensor(tile_dstr); auto issue = [&](auto i_access_) { constexpr auto IAccess = number{}; @@ -492,35 +328,29 @@ struct tile_window_linear // read from bottom tensor const vector_t vec_value = - get_bottom_tensor_view().template get_vectorized_elements( + this->get_bottom_tensor_view().template get_vectorized_elements( bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, bool_constant{}); -#if 1 + // data index [y0, y1, ...] constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess); // write into distributed tensor - static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { + static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { - return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj]; + return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j) + : idx_diff_ys[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / - traits::PackedSize; + Base::Traits::PackedSize; - dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j / traits::PackedSize]; + dst_tensor.get_thread_buffer().template at() = vec_value.template get_as< + typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize]; }); -#else - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); - static_assert(d % traits::ScalarPerVector == 0); - - dst_tensor.get_thread_buffer().template get_as()( - number{}) = bit_cast(vec_value); -#endif }; WINDOW_DISPATCH_ISSUE(); @@ -533,10 +363,10 @@ struct tile_window_linear number = {}, bool_constant = {}) const { - using vector_t = typename traits::vector_t; - using SFC_Ys = typename traits::SFC_Ys; + using vector_t = typename Base::Traits::vector_t; + using SFC_Ys = typename Base::Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // auto dst_tensor = make_static_distributed_tensor(tile_dstr); @@ -551,35 +381,28 @@ struct tile_window_linear // read from bottom tensor const vector_t vec_value = - get_bottom_tensor_view().template get_vectorized_elements( + this->get_bottom_tensor_view().template get_vectorized_elements( bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, bool_constant{}); -#if 1 // data index [y0, y1, ...] constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess); // write into distributed tensor - static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { + static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { - return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj]; + return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j) + : idx_diff_ys[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / - traits::PackedSize; + Base::Traits::PackedSize; - dst_tensor.get_thread_buffer().template at() = - vec_value.template get_as()[j / traits::PackedSize]; + dst_tensor.get_thread_buffer().template at() = vec_value.template get_as< + typename Base::DataTypeDataType>()[j / Base::Traits::PackedSize]; }); -#else - constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); - static_assert(d % traits::ScalarPerVector == 0); - - dst_tensor.get_thread_buffer().template get_as()( - number{}) = bit_cast(vec_value); -#endif }; WINDOW_DISPATCH_ISSUE(); @@ -596,15 +419,17 @@ struct tile_window_linear bool_constant = {}, bool_constant = {}) const { - using vector_t = typename traits::vector_t; - using SFC_Ys = typename traits::SFC_Ys; + using vector_t = typename Base::Traits::vector_t; + using SFC_Ys = typename Base::Traits::SFC_Ys; static constexpr index_t YElementSize = - TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); - static_assert(YElementSize % (traits::PackedSize * traits::ScalarPerVector) == 0); + typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); + static_assert(YElementSize % (Base::Traits::PackedSize * Base::Traits::ScalarPerVector) == + 0); using vectorized_tbuf = - array; + array; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; auto& dst_vec_tbuf = reinterpret_cast(dst_tensor.get_thread_buffer()); @@ -612,7 +437,7 @@ struct tile_window_linear constexpr auto IAccess = number{}; constexpr auto pre_nop_ = [&]() { if constexpr(pre_nop && i_access_ == 0 && - BottomTensorView::buffer_view::get_address_space() == + Base::BottomTensorView::buffer_view::get_address_space() == address_space_enum::global) return bool_constant{}; else @@ -628,11 +453,11 @@ struct tile_window_linear constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) / - traits::PackedSize; - static_assert(d % traits::ScalarPerVector == 0); + Base::Traits::PackedSize; + static_assert(d % Base::Traits::ScalarPerVector == 0); - get_bottom_tensor_view().template get_vectorized_elements_raw( - dst_vec_tbuf.template at(), + this->get_bottom_tensor_view().template get_vectorized_elements_raw( + dst_vec_tbuf.template at(), bottom_tensor_thread_coord, linear_offset /**/, bottom_tensor_flag, @@ -663,7 +488,7 @@ struct tile_window_linear // currently we only support everything is non linear dim // actually it's not performant if we have linear dim(e.g. fast changing) static_assert(NumAccess_NonLinear == NumAccess); - static_assert(BottomTensorView::buffer_view::get_address_space() == + static_assert(Base::BottomTensorView::buffer_view::get_address_space() == address_space_enum::global); // issues * warps * lanes @@ -689,7 +514,7 @@ struct tile_window_linear const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); m0_set_with_memory(m0_init_value); // This should be wave independent - using vector_t = typename traits::vector_t; + using vector_t = typename Base::Traits::vector_t; LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; @@ -708,7 +533,7 @@ struct tile_window_linear auto bottom_tensor_flag = cached_flags_[IAccess]; // get this flag anyway // read from bottom tensor - get_bottom_tensor_view().template async_get_vectorized_elements_raw( + this->get_bottom_tensor_view().template async_get_vectorized_elements_raw( smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_); // move thread coordinate @@ -732,7 +557,7 @@ struct tile_window_linear // currently we only support everything is non linear dim // actually it's not performant if we have linear dim(e.g. fast changing) static_assert(NumAccess_NonLinear == NumAccess); - static_assert(BottomTensorView::buffer_view::get_address_space() == + static_assert(Base::BottomTensorView::buffer_view::get_address_space() == address_space_enum::global); // issues * warps * lanes @@ -757,7 +582,7 @@ struct tile_window_linear const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); - using vector_t = typename traits::vector_t; + using vector_t = typename Base::Traits::vector_t; // TODO: we force CK_TILE_LDS_ADDR CK_TILE_LDS_ADDR LdsDataType* smem = @@ -771,7 +596,7 @@ struct tile_window_linear auto bottom_tensor_flag = cached_flags_[IAccess]; // read from bottom tensor - get_bottom_tensor_view().template async_get_vectorized_elements( + this->get_bottom_tensor_view().template async_get_vectorized_elements( smem, bottom_tensor_thread_coord, 0, @@ -789,15 +614,16 @@ struct tile_window_linear } template - CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, + CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, number = {}, bool_constant = {}) const { - using vector_t = typename traits::vector_t; - using SFC_Ys = typename traits::SFC_Ys; + using vector_t = typename Base::Traits::vector_t; + using SFC_Ys = typename Base::Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // loop over thread tensor space [y0, y1, ...] auto issue = [&](auto i_access_) { @@ -812,22 +638,23 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { + static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { - return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; + return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / - traits::PackedSize; + Base::Traits::PackedSize; - vec_value.template get_as()(j / traits::PackedSize) = + vec_value.template get_as()(j / Base::Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // write into bottom tensor - get_bottom_tensor_view().template set_vectorized_elements( + this->get_bottom_tensor_view().template set_vectorized_elements( bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, @@ -839,13 +666,15 @@ struct tile_window_linear } template - CK_TILE_DEVICE void store_raw(const static_distributed_tensor& dstr_tensor, - number = {}) const + CK_TILE_DEVICE void + store_raw(const static_distributed_tensor& + dstr_tensor, + number = {}) const { - using vector_t = typename traits::vector_t; - using SFC_Ys = typename traits::SFC_Ys; + using vector_t = typename Base::Traits::vector_t; + using SFC_Ys = typename Base::Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; static constexpr bool oob_conditional_check = true; // loop over thread tensor space [y0, y1, ...] @@ -861,20 +690,21 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { + static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { - return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; + return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / - traits::PackedSize; - vec_value.template get_as()(j / traits::PackedSize) = + Base::Traits::PackedSize; + vec_value.template get_as()(j / Base::Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // write into bottom tensor - get_bottom_tensor_view() + this->get_bottom_tensor_view() .template set_vectorized_elements_raw( bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value); }; @@ -883,15 +713,17 @@ struct tile_window_linear } template - CK_TILE_DEVICE void update(const static_distributed_tensor& dstr_tensor, - number = {}, - bool_constant = {}) const + CK_TILE_DEVICE void + update(const static_distributed_tensor& + dstr_tensor, + number = {}, + bool_constant = {}) const { - using vector_t = typename traits::vector_t; - using SFC_Ys = typename traits::SFC_Ys; + using vector_t = typename Base::Traits::vector_t; + using SFC_Ys = typename Base::Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // loop over thread tensor space [y0, y1, ...] auto issue = [&](auto i_access_) { @@ -907,22 +739,24 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { + static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { - return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; + return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / - traits::PackedSize; + Base::Traits::PackedSize; - vec_value.template get_as()(j / traits::PackedSize) = + vec_value.template get_as()( + j / Base::Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // write into bottom tensor - get_bottom_tensor_view().template update_vectorized_elements( + this->get_bottom_tensor_view().template update_vectorized_elements( bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, @@ -934,16 +768,18 @@ struct tile_window_linear } template - CK_TILE_DEVICE void update_raw(const static_distributed_tensor& dstr_tensor, - number = {}, - bool_constant = {}, - bool_constant = {}) const + CK_TILE_DEVICE void + update_raw(const static_distributed_tensor& + dstr_tensor, + number = {}, + bool_constant = {}, + bool_constant = {}) const { - using vector_t = typename traits::vector_t; - using SFC_Ys = typename traits::SFC_Ys; + using vector_t = typename Base::Traits::vector_t; + using SFC_Ys = typename Base::Traits::SFC_Ys; - constexpr auto tile_dstr = TileDstr{}; + constexpr auto tile_dstr = typename Base::TileDstr{}; // loop over thread tensor space [y0, y1, ...] auto issue = [&](auto i_access_) { @@ -959,22 +795,24 @@ struct tile_window_linear // read from distributed tensor vector_t vec_value; - static_for<0, traits::ScalarPerVector, traits::PackedSize>{}([&](auto j) { + static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](auto j) { constexpr auto idx_ys = generate_tuple( [&](auto jj) { - return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; + return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; }, - number{}); + number{}); constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / - traits::PackedSize; + Base::Traits::PackedSize; - vec_value.template get_as()(j / traits::PackedSize) = + vec_value.template get_as()( + j / Base::Traits::PackedSize) = dstr_tensor.get_thread_buffer().template at(); }); // write into bottom tensor - get_bottom_tensor_view().template update_vectorized_elements_raw( + this->get_bottom_tensor_view().template update_vectorized_elements_raw( bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, @@ -985,14 +823,10 @@ struct tile_window_linear WINDOW_DISPATCH_ISSUE(); } - - // move thread's botom tensor coordiante - // [x0', x1', ... ] ==> [offset] - // also move window-origin - CK_TILE_DEVICE void move(const BottomTensorIndex& step) + // *_extended() functions acts like a virtual function with a default implementation exisiting + // in the base class + CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex& step) { - window_origin_ += step; - static_for<0, NumAccess, 1>{}([&](auto i_access) { constexpr auto IAccess = number{}; constexpr auto non_linear_id = number{}; @@ -1001,7 +835,7 @@ struct tile_window_linear if constexpr(need_update_non_linear_coord) { - move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(), cached_coords_(non_linear_id), step); } @@ -1010,30 +844,29 @@ struct tile_window_linear auto tmp_coords = cached_coords_[non_linear_id]; constexpr auto linear_coord = get_bottom_linear_coordinate(IAccess); move_tensor_coordinate( - bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord); + this->bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord); cached_flags_(IAccess) = coordinate_has_valid_offset_assuming_top_index_is_valid( - bottom_tensor_view_.get_tensor_descriptor(), tmp_coords); + this->bottom_tensor_view_.get_tensor_descriptor(), tmp_coords); }); } - CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex&) { - window_origin_ = new_window_origin; - auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( - TileDstr{}.get_ps_ys_to_xs_adaptor(), - container_concat(make_tuple(get_warp_id(), get_lane_id()), - generate_tuple([&](auto) { return number<0>{}; }, number{}))); + typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor(), + container_concat( + make_tuple(get_warp_id(), get_lane_id()), + generate_tuple([&](auto) { return number<0>{}; }, number{}))); - BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = - window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); + typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( - bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); // future load/store() calls (might allocate more registers) - using SFC_Ys = typename traits::SFC_Ys; + using SFC_Ys = typename Base::Traits::SFC_Ys; static_for<0, NumAccess, 1>{}([&](auto i_access) { constexpr auto non_linear_id = number{}; @@ -1049,10 +882,10 @@ struct tile_window_linear { constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number constexpr auto idx_diff_ps_ys = container_concat( - generate_tuple([&](auto) { return number<0>{}; }, number{}), + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); - move_window_adaptor_and_bottom_tensor_thread_coordinate( + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( window_adaptor_thread_coord_tmp, bottom_tensor_thread_coord_tmp, idx_diff_ps_ys); @@ -1060,26 +893,9 @@ struct tile_window_linear }); } - CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); } - - // this is the bottom tensor view - // [x0', x1', ...] ==> [offset] - BottomTensorView bottom_tensor_view_; - - // - WindowLengths window_lengths_; - - // origin ([x0', x1', ...]) of window on bottom tensor - BottomTensorIndex window_origin_; - - // Tile tensor distribution, which contains: - // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] - // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d] - TileDstr tile_dstr_; - // this contains: - array cached_coords_; - array cached_flags_; + array cached_coords_; + array cached_flags_; }; #undef WINDOW_DISPATCH_ISSUE