From 1e2dac15a16dbe3f0806834853299bdab026c2fa Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Mon, 10 Nov 2025 13:11:13 -0600 Subject: [PATCH] save tmp --- .../core/algorithm/coordinate_transform.hpp | 15 +- .../core/tensor/distributed_lds_tensor.hpp | 164 ++++++++++++++++++ .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 18 +- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 8 + ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 2 + shared_banks/test.cpp | 19 ++ shared_banks/test_with_elementwise.cpp | 144 +++++++++++++++ 7 files changed, 365 insertions(+), 5 deletions(-) create mode 100644 include/ck_tile/core/tensor/distributed_lds_tensor.hpp create mode 100644 shared_banks/test.cpp create mode 100644 shared_banks/test_with_elementwise.cpp diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index 7511413bba..dd4a6b08bc 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -1326,8 +1326,21 @@ struct xor_t : public base_transform<2, 2> idx_low(number<0>{}) = idx_up[number<0>{}]; + // original (1.0 rate): + // idx_low(number<1>{}) = + // idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]); + // 0.5 rate + // idx_low(number<1>{}) = + // idx_up[number<1>{}] ^ (2 * idx_up[number<0>{}]) % up_lengths_[number<1>{}]; idx_low(number<1>{}) = - idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]); + idx_up[number<1>{}] ^ (2 * idx_up[number<0>{}] % up_lengths_[number<1>{}]); + // if (threadIdx.x < 64) + // { + // printf("lane: %u | idx_low: (%d, %d) | idx_up: (%d, %d)\n", + // __lane_id(), + // idx_low[number<0>{}], idx_low[number<1>{}], + // idx_up[number<0>{}], idx_up[number<1>{}]); + // } } template diff --git a/include/ck_tile/core/tensor/distributed_lds_tensor.hpp b/include/ck_tile/core/tensor/distributed_lds_tensor.hpp new file mode 100644 index 0000000000..19dcb69fdc --- /dev/null +++ b/include/ck_tile/core/tensor/distributed_lds_tensor.hpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/tensor/tensor_descriptor.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/arch/arch.hpp" + +namespace ck_tile +{ + +// A static distributed tensor backed by shared memory (LDS), +// with indexing defined by a naive packed tensor descriptor of the Y-dimensions +// of the provided StaticTileDistribution. +// +// Notes: +// - This does NOT allocate LDS. Callers must provide a __shared__-allocated buffer +// of size get_lds_buffer_size() elements of DataType. +// - Index mapping uses a naive packed descriptor derived from the Y lengths of the +// distribution to compute a linear offset into LDS. +// - Distributed indexing API mirrors static_distributed_tensor (operator[]/operator()). +// +// Example (inside __global__ kernel): +// using Dist = decltype(make_static_tile_distribution(my_encoding{})); +// using Tensor = ck_tile::static_distributed_lds_tensor; +// +// // Compile-time LDS size in elements +// __shared__ float lds[Tensor::get_lds_buffer_size()]; +// +// Tensor t{lds}; +// // write using distributed indices (compile-time indices) +// t(ck_tile::detail::make_tile_distributed_index(ck_tile::sequence<0>{}, +// ck_tile::sequence<0>{})) = 1.0f; +// ck_tile::block_sync_lds(); +// +// // read back similarly +// auto v = t[ck_tile::detail::make_tile_distributed_index(ck_tile::sequence<0>{}, +// ck_tile::sequence<0>{})]; +// +template +struct static_distributed_lds_tensor +{ + using DataType = remove_cvref_t; + using StaticTileDistribution = remove_cvref_t; + + static_assert(StaticTileDistribution::is_static(), + "StaticTileDistribution must be fully known at compile time"); + + // Build a naive packed Y descriptor based on the distribution's Y lengths. + // This mirrors the usage pattern from static_distributed_tensor where offsets are + // computed through the Ys-to-D descriptor, but here we explicitly construct a naive + // packed descriptor to satisfy the requirement. + using YsDescriptorOriginal = + remove_cvref_t; + + using YsNaivePackedDescriptor = + remove_cvref_t; + + using ThreadTensorDesc = YsNaivePackedDescriptor; + + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t kLdsElementSpaceSize = ThreadTensorDesc{}.get_element_space_size(); + static_assert(kLdsElementSpaceSize > 0, "Make sure tile distribution is valid for LDS"); + + // Number of DataType elements required in LDS buffer (considering PackedSize for DataType) + CK_TILE_HOST_DEVICE static constexpr index_t get_lds_buffer_size() + { + return kLdsElementSpaceSize / PackedSize; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension() + { + return StaticTileDistribution::get_num_of_dimension_x(); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_lengths() + { + return StaticTileDistribution::get_lengths(); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_tile_distribution() + { + return StaticTileDistribution{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans() + { + return StaticTileDistribution::get_distributed_spans(); + } + + // Construct with an LDS pointer (caller-provided __shared__ storage). + CK_TILE_HOST_DEVICE explicit constexpr static_distributed_lds_tensor(DataType* lds_ptr) + : p_lds_{lds_ptr} + { + } + + // Bind or rebind LDS pointer at runtime if needed. + CK_TILE_DEVICE void bind(DataType* lds_ptr) { p_lds_ = lds_ptr; } + + CK_TILE_DEVICE constexpr const DataType* data() const { return p_lds_; } + CK_TILE_DEVICE constexpr DataType* data() { return p_lds_; } + + // Distributed indexing: const access + template + CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const + { + static_assert(is_static_v, + "TileDistributedIndices must be static"); + constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices( + TileDistributedIndices{}); + constexpr index_t linear = ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize; + return p_lds_[linear]; + } + + // Distributed indexing: mutable access + template + CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices) + { + static_assert(is_static_v, + "TileDistributedIndices must be static"); + constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices( + TileDistributedIndices{}); + constexpr index_t linear = ThreadTensorDesc{}.calculate_offset(y_idx) / PackedSize; + return p_lds_[linear]; + } + + // Returns the naive Y-descriptor (type-only, constexpr-friendly) + CK_TILE_HOST_DEVICE static constexpr auto get_y_naive_packed_descriptor() + { + return ThreadTensorDesc{}; + } + +private: + DataType* p_lds_ = nullptr; // shared (LDS) pointer provided by caller +}; + +// Helper: make function +template +CK_TILE_HOST_DEVICE constexpr auto +make_static_distributed_lds_tensor(DataType* lds_ptr, const StaticTileDistribution&) +{ + return static_distributed_lds_tensor, + remove_cvref_t>{lds_ptr}; +} + +// Utility to query required LDS size (in elements of DataType) for a given distribution +template +CK_TILE_HOST_DEVICE constexpr index_t +get_required_lds_size_elems(const StaticTileDistribution&) +{ + using tensor_t = static_distributed_lds_tensor, + remove_cvref_t>; + return tensor_t::get_lds_buffer_size(); +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index b5584f98df..c8837180c4 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -59,6 +59,10 @@ struct GemmPipelineAgBgCrImplBase const ElementFunction& element_func) const { const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile); + ignore = lds_tile_window; + ignore = src_block_tile; + ignore = element_func; + ignore = block_tile_tmp; store_tile(lds_tile_window, block_tile_tmp); } @@ -66,6 +70,10 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window, const SrcBlockTile& src_block_tile) const { + ignore = lds_tile_window; + ignore = src_block_tile; + // CK_PRINT::StaticTileDistribution::DstrEncode>(); + store_tile(lds_tile_window, src_block_tile); } @@ -74,10 +82,12 @@ struct GemmPipelineAgBgCrImplBase const SrcTileWindow& lds_tile_window, bool_constant = {}) const { - if constexpr(LoadTranspose) - dst_block_tile = load_tile_transpose(lds_tile_window); - else - load_tile(dst_block_tile, lds_tile_window); + // if constexpr(LoadTranspose) + // dst_block_tile = load_tile_transpose(lds_tile_window); + // else + // load_tile(dst_block_tile, lds_tile_window); + ignore = dst_block_tile; + ignore = lds_tile_window; } CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index aaa04615fd..47515e2c9c 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -444,12 +444,16 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr auto b_lds_load_tile_distr = make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + // CK_PRINT(); + // A DRAM tile window for load // A LDS tile window for store // A LDS tile for block GEMM auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + // CK_PRINT{}])>::Base::TileDstr::DstrEncode>(); + // B DRAM tile window for load // B LDS tile window for store // B LDS tile for block GEMM @@ -478,6 +482,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 auto elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + // CK_PRINT(); + // CK_PRINT{}])>::Base::TileDstr::DstrEncode>(); + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple // as input. move_tile_window(a_copy_dram_window, a_dram_tile_window_step); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index ecff6fe497..cecb9b097c 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -545,6 +545,8 @@ struct UniversalGemmBasePolicy VecLoadSize, getATileAccessPattern(), NumWaveGroups>; + CK_PRINT(); + CK_PRINT::DstrEncode>(); return TileEncodingPattern::make_2d_static_tile_distribution(); } // Tile: KPerBlock X MPerBlock diff --git a/shared_banks/test.cpp b/shared_banks/test.cpp new file mode 100644 index 0000000000..0ed2b4a6f4 --- /dev/null +++ b/shared_banks/test.cpp @@ -0,0 +1,19 @@ +#include + +__global__ void conflict_simulator() +{ + // __shared__ int buf[64]; + // buf[threadIdx.x] = threadIdx.x; // -> 0 conflicts + // + // __shared__ int buf[2 * 64]; + // buf[2 * threadIdx.x] = threadIdx.x; // -> 1 conflict per access (2-way) + // + __shared__ int buf[4 * 64]; + buf[4 * threadIdx.x] = threadIdx.x; // -> 3 conflicts per access (4-way) +} + +int main() +{ + conflict_simulator<<<1, 64, 0, 0>>>(); + return 0; +} diff --git a/shared_banks/test_with_elementwise.cpp b/shared_banks/test_with_elementwise.cpp new file mode 100644 index 0000000000..f96d631868 --- /dev/null +++ b/shared_banks/test_with_elementwise.cpp @@ -0,0 +1,144 @@ +#include +#include +#include + +using ck_tile::number; +using ck_tile::make_naive_tensor_descriptor; +using ck_tile::make_tile_window; +using ck_tile::make_tensor_view; +using ck_tile::make_tuple; +using ck_tile::address_space_enum; +using ck_tile::fp16_t; +using ck_tile::index_t; + +using ck_tile::tile_distribution_encoding; +using ck_tile::sequence; +using ck_tile::tuple; +using ck_tile::make_static_distributed_tensor; +using ck_tile::make_static_tile_distribution; +using ck_tile::get_n_lds_banks; +using ck_tile::get_n_words_per_128b; +using ck_tile::make_xor_transform; +using ck_tile::make_pass_through_transform; +using ck_tile::make_unmerge_transform; + +constexpr index_t kBlockSize = 256; + +template +__device__ constexpr auto make_lds_tensor_descriptor() +{ + constexpr auto DataTypeSize = sizeof(DataType); + constexpr index_t KPack = 16 / sizeof(DataType); + + constexpr auto MLdsLayer = + ck_tile::max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; +} + +template +__device__ +constexpr auto make_register_distribution() +{ + using ck_tile::tile_distribution_encoding_pattern_2d; + using ck_tile::tile_distribution_pattern; + constexpr auto VecLoadSize = 16 / sizeof(DType); + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); +} + +template +__global__ void lds_write_simulator() +{ + __shared__ int buf[160000 / sizeof(int)]; + + // lds setup + auto lds_tensor_descriptor = make_lds_tensor_descriptor(); + auto lds_tensor_view = make_tensor_view(reinterpret_cast(buf), lds_tensor_descriptor); + auto lds_write_window = make_tile_window(lds_tensor_view, make_tuple(number{}, number{}), {0, 0}); + + // register setup + auto reg_tile_dst = make_register_distribution(); + auto reg_tensor = make_static_distributed_tensor(reg_tile_dst); + + // writeout + store_tile(lds_write_window, reg_tensor); + + ck_tile::block_sync_lds(); +} + +__device__ +constexpr auto make_lds_read_distribution() +{ + return make_static_tile_distribution(tile_distribution_encoding< + ck_tile::sequence<2>, + ck_tile::tuple, ck_tile::sequence<4, 4, 4>>, + ck_tile::tuple, ck_tile::sequence<2, 1>>, + ck_tile::tuple, ck_tile::sequence<1, 2>>, + ck_tile::sequence<1, 2, 2>, + ck_tile::sequence<0, 0, 2>>{}); +} + +template +__global__ void lds_read_simulator() +{ + __shared__ int buf[160000 / sizeof(int)]; + + auto lds_tensor_descriptor = make_lds_tensor_descriptor(); + auto lds_tensor_view = make_tensor_view(reinterpret_cast(buf), lds_tensor_descriptor); + + constexpr auto reg_tile_dst = make_lds_read_distribution(); + auto lds_read_window = make_tile_window(lds_tensor_view, make_tuple(number{}, number{}), {0, 0}, reg_tile_dst); + + [[maybe_unused]] auto reg_tile = load_tile(lds_read_window); + ck_tile::block_sync_lds(); +} + +int main() +{ + constexpr auto kGrid = 1; + constexpr auto kBlockM = 128; + constexpr auto kBlockK = 64; + lds_write_simulator<<>>(); + lds_read_simulator<<>>(); + return 0; +}