From a57f8d8b67cf28547b61dafd2160dffb33b6eaeb Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 29 Dec 2025 23:05:35 +0800 Subject: [PATCH] [CK_TILE] support split-k a16w4 gemm1 (#3389) * initial version to support moe gemm1 split-k * add missing args * fix build warning * update reference * for split-k disable bias and weight * remove debug log * fix format * fix div by zero errors * fix cmake config * update * resolve conflicts * remove useless changes * reformat * fix * remove useless changes * fix ci --------- Co-authored-by: lalala-sh Co-authored-by: root [ROCm/composable_kernel commit: dae85ead64c16b34eaa643d09fb0d6da008ca814] --- example/ck_tile/18_flatmm/CMakeLists.txt | 5 +- .../18_flatmm/mixed_prec/a16w4_flatmm.hpp | 2 +- .../18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp | 31 +++++++++- .../18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp | 7 ++- .../run_a16w4_moe_flatmm_example.inc | 19 +++--- .../mixed_prec/run_mixed_prec_flatmm.inc | 5 +- .../mxgemm/mx_flatmm_instance.cpp.in | 6 +- .../18_flatmm/mxgemm/run_mx_flatmm.inc | 10 ++-- .../host/reference/reference_moe_gemm.hpp | 32 +++++----- .../ops/flatmm/kernel/flatmm_kernel.hpp | 37 ++++++------ .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 60 ++++++++++++------- 11 files changed, 136 insertions(+), 78 deletions(-) diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 696cb4f60b..7451ee25b0 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -31,13 +31,14 @@ if(has_supported_gpu) add_executable(tile_example_grouped_flatmm grouped_flatmm.cpp) target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) - if (GPU_TARGETS MATCHES "gfx95") + if(GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx94") add_executable(tile_example_mixed_prec_flatmm mixed_prec/mixed_prec_flatmm.cpp) target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) add_executable(tile_example_a16w4_moe_flatmm mixed_prec/a16w4_moe_flatmm.cpp) target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) - + endif() + if (GPU_TARGETS MATCHES "gfx95") include(mxgemm/mx_flatmm_instance.cmake) mx_flatmm_instance_generate(EXAMPLE_MX_FLATMM_FILES) message(STATUS "Generated MX FlatMM kernel files: ${EXAMPLE_MX_FLATMM_FILES}") diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_flatmm.hpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_flatmm.hpp index 7dc53736b4..fcd60ec1c6 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_flatmm.hpp @@ -8,7 +8,7 @@ // GEMM config with 16x16 warp tile struct A16W4_FlatmmConfig16 { - static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t M_Tile = 64; static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 256; diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp index 0678e87e47..fe7fe4c5d1 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp @@ -191,13 +191,15 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n" + std::cout << "Launching kernel " << Kernel::GetName() << "\n" + << "with args:" << CodegenFlatmmShape::GetName() << "\n" << "Shape: " << CodegenFlatmmShape::GetName() << "\n" << "problem: " << CodegenPipelineProblem::GetName() << "\n" << "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n" << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + << "\n" + << "k_batch: " << kargs.k_batch << std::endl; } if(s.flush_cache_) @@ -471,10 +473,33 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[]) throw std::runtime_error("Unsupported precision type for gemm2!"); } } + else if(gemm_kind == "gemm1_split_k") + { + if(mixed_prec == "fp16xfp4") + { + return run_a16w4_moe_gemm_example_with_layouts< + ck_tile::half_t, + ck_tile::pk_fp4_t, + FlatmmConfig, + ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k>(argc, argv, Row{}, Col{}, Row{}); + } + else if(mixed_prec == "bf16xfp4") + { + return run_a16w4_moe_gemm_example_with_layouts< + ck_tile::bfloat16_t, + ck_tile::pk_fp4_t, + FlatmmConfig, + ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k>(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported precision type for gemm1_split_k!"); + } + } else { throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value " - "[gemm1_gate_up | gemm2]"); + "[gemm1_gate_up | gemm1_split_k | gemm2]"); } } else diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp index 9f4fc152be..bf305124f7 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp @@ -13,7 +13,7 @@ // GEMM config with 16x16 warp tile struct A16W4_FlatmmConfig16 { - static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t M_Tile = 32; static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 256; @@ -69,7 +69,7 @@ auto create_args(int argc, char* argv[]) .insert("c_layout", "R", "C tensor data layout - Row by default.") .insert("gemm_kind", "gemm1_gate_up", - "Gemm kind in FFN network [gemm1_gate_up | gemm2] - " + "Gemm kind in FFN network [gemm1_gate_up | gemm2 | gemm1_split_k] - " "gemm1_gate_up by default.") .insert("validate", "1", "0. No validation, 1. Validation on CPU.") .insert("warmup", "50", "number of iterations before benchmark the kernel") @@ -80,7 +80,8 @@ auto create_args(int argc, char* argv[]) .insert("warp_tile", "0", "0: 16x16, 1: 16x16 (950 only, may use a larger tile than warp_tile=0)") - .insert("repeat", "10", "number of iterations to benchmark the kernel."); + .insert("repeat", "10", "number of iterations to benchmark the kernel.") + .insert("k_batch", "1", "parallism to control splik-k."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc index 45df126540..228eaf4e3d 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc @@ -67,9 +67,12 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc, return -1; }; - using ADataType = PrecActType; - using BDataType = PrecWeightType; - using CDataType = PrecActType; + using ADataType = PrecActType; + using BDataType = PrecWeightType; + using ADataType = PrecActType; + using BDataType = PrecWeightType; + using CDataType = + std::conditional_t; using AccDataType = float; using ScaleType = ck_tile::e8m0_t; @@ -88,6 +91,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc, const ck_tile::index_t warmup = arg_parser.get_int("warmup"); const ck_tile::index_t repeat = arg_parser.get_int("repeat"); const ck_tile::index_t experts = arg_parser.get_int("experts"); + const ck_tile::index_t k_batch = arg_parser.get_int("k_batch"); // TODO: replace the magic declaration const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile; @@ -231,14 +235,15 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc, static_cast(expert_weight_dev.GetDeviceBuffer()); auto scale_b_shuffle_dev_ptr = - ck_tile::FlatmmScalePointer{ - static_cast(scale_b_shuffle_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN}; + ck_tile::FlatmmScalePointer{ + static_cast(scale_b_shuffle_dev_buf.GetDeviceBuffer()), + N / ScaleGranularityN}; auto exp_bias_dev_ptr = ck_tile::FlatmmScalePointer<1>{ static_cast(expert_bias_dev.GetDeviceBuffer()), experts * N}; using MoeFlatmmArgs = ck_tile::MoeFlatmmHostArgs< ck_tile::FlatmmScalePointer<-1>, - ck_tile::FlatmmScalePointer, + ck_tile::FlatmmScalePointer, ck_tile::FlatmmScalePointer<1>>; MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev, p_sorted_expert_weight_dev, @@ -250,7 +255,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc, num_tokens, experts, topk, - 1, // k_batch + k_batch, // k_batch M, N, K, diff --git a/example/ck_tile/18_flatmm/mixed_prec/run_mixed_prec_flatmm.inc b/example/ck_tile/18_flatmm/mixed_prec/run_mixed_prec_flatmm.inc index 552f10348c..e2071e2d55 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/run_mixed_prec_flatmm.inc +++ b/example/ck_tile/18_flatmm/mixed_prec/run_mixed_prec_flatmm.inc @@ -85,8 +85,9 @@ int run_mixed_prec_flatmm_with_layouts(int argc, c_rslt_host.SetZero(); scale_b_dev_buf.ToDevice(scale_b_shuffle.data()); - auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer{ - static_cast(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN}; + auto scale_b_dev_ptr = + ck_tile::FlatmmScalePointer{ + static_cast(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN}; invoke_mixed_prec_flatmm; -using ScaleN = ck_tile::FlatmmScalePointer; +using ScaleM = ck_tile::FlatmmScalePointer; +using ScaleN = ck_tile::FlatmmScalePointer; template float mx_flatmm_calc{ - static_cast(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM}; - auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer{ - static_cast(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN}; + auto scale_a_dev_ptr = + ck_tile::FlatmmScalePointer{ + static_cast(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM}; + auto scale_b_dev_ptr = + ck_tile::FlatmmScalePointer{ + static_cast(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN}; invoke_mx_flatmm __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_, const ck_tile::index_t* p_sorted_expert_ids_, @@ -43,10 +43,11 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_, float* scale_B_ptr, float* expert_bias_ptr) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int problem_N = MoeGemmKind == 1 ? N / 2 : N; - int row = idx / problem_N; // Compute row index - int col = idx % problem_N; // Compute column index + constexpr auto is_split_k = MoeGemmKind == 3; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int problem_N = MoeGemmKind == 1 ? N / 2 : N; + int row = idx / problem_N; // Compute row index + int col = idx % problem_N; // Compute column index index_t gather_token_id = 0; index_t scatter_token_id = 0; @@ -203,7 +204,7 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_, acc_up += acc_up_temp * scale_A * scale_B_up; float bias = 0.f, bias_up = 0.f; - if(expert_bias_ptr != nullptr) + if(expert_bias_ptr != nullptr && !is_split_k) { bias = expert_bias_ptr[expert_id * N + col]; if constexpr(MoeGemmKind == 1) @@ -221,23 +222,24 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_, else { // moe gemm2 don't use activation. - CDataType res = ck_tile::type_convert((acc + bias) * expert_weight_ptr[row]); - using ResV2Type = std::conditional_t, - ck_tile::fp16x2_t, - ck_tile::bf16x2_t>; - ResV2Type add_v{0, 0}; + auto weight = + is_split_k ? ck_tile::type_convert(1.0f) : expert_weight_ptr[row]; + CDataType res = ck_tile::type_convert((acc + bias) * weight); + + thread_buffer add_v = 0; if(c_index % 2) { // result is the second value of fp16 pair. - add_v.y = res; + add_v.template get_as()[1] = res; } else { // result is the first value of fp16 pair. - add_v.x = res; + add_v.template get_as()[0] = res; } // mask last bit to make sure atomicAdd pointer is aligned of DWORD. - atomic_add(reinterpret_cast(C + (c_index & 0xffff'fffe)), add_v); + atomic_add_g(reinterpret_cast(C + (c_index & 0xffff'fffe)), + add_v); } } } @@ -249,7 +251,7 @@ template void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_, const index_t* p_sorted_expert_ids_, diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 09204aa7ed..9a33801c8f 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -28,17 +28,18 @@ struct FlatmmProblem index_t stride_C; }; -template +template struct FlatmmScalePointer { + using ScaleType = ScaleType_; static constexpr int GranularityMN = SharedGranularityMN; static constexpr int GranularityK = SharedGranularityK; - const float* ptr; + const ScaleType* ptr; CK_TILE_HOST_DEVICE FlatmmScalePointer() = default; - CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {} - CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_) + CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_) {} + CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, [[maybe_unused]] index_t length_) : ptr(ptr_) { } @@ -57,23 +58,24 @@ struct FlatmmScalePointer return ret; } - CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete; + CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const = delete; }; -template -struct FlatmmScalePointer +template +struct FlatmmScalePointer { + using ScaleType = ScaleType_; static constexpr int GranularityMN = SharedGranularityMN; static constexpr int GranularityK = 0; static_assert(GranularityMN != 0); - const float* ptr; + const ScaleType* ptr; index_t length; CK_TILE_HOST_DEVICE FlatmmScalePointer() = default; - CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {} - CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t length_) + CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_), length(1) {} + CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, index_t length_) : ptr(ptr_), length(length_) { } @@ -94,7 +96,7 @@ struct FlatmmScalePointer return ret; } - CK_TILE_HOST_DEVICE float operator[](index_t i) const + CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const { // with additional oob check if constexpr(GranularityMN == 1) @@ -105,23 +107,24 @@ struct FlatmmScalePointer }; // shared granularityMN = -1 means no scale -template <> -struct FlatmmScalePointer<-1, 0> +template +struct FlatmmScalePointer<-1, 0, ScaleType_> { + using ScaleType = ScaleType_; static constexpr int GranularityMN = -1; static constexpr int GranularityK = 0; - const float* ptr = nullptr; + const ScaleType* ptr = nullptr; CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default; - CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {} - CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, index_t) {} + CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*) {} + CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*, index_t) {} CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const { return FlatmmScalePointer{}; } - CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const + CK_TILE_HOST_DEVICE constexpr ScaleType operator[](index_t) const { return 1; // alway return 1, it doesn't change the result } diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index cc3306f0fc..b47ec4a829 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -132,6 +132,7 @@ enum class MoeFlatmmKind kFFN_gemm1_gate_only, kFFN_gemm1_gate_up, kFFN_gemm2, + kFFN_gemm1_split_k, }; namespace moe { @@ -222,8 +223,10 @@ struct MoeFlatmmKernel static_assert(DsLayout::size() == DsDataType::size(), "The size of DsLayout and DsDataType should be the same"); - static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2; - static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up; + static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2; + static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up; + static constexpr bool IsGemm1SplitK = kind == MoeFlatmmKind::kFFN_gemm1_split_k; + static constexpr bool IsBShuffled = true; // static constexpr index_t kBlockSize = EpiloguePipeline::kBlockSize; static constexpr index_t kMPerBlock = EpiloguePipeline::kMPerBlock; @@ -395,15 +398,6 @@ struct MoeFlatmmKernel a_k_split_offset = k_id * KRead * kargs.stride_A; } - if constexpr(std::is_same_v) - { - b_k_split_offset = k_id * KRead * kargs.stride_B; - } - else if constexpr(std::is_same_v) - { - b_k_split_offset = k_id * KRead; - } - if(k_id < static_cast(kargs.k_batch - 1)) { splitted_k = KRead; @@ -412,6 +406,22 @@ struct MoeFlatmmKernel { splitted_k = kargs.K - KRead * (kargs.k_batch - 1); } + + if constexpr(IsBShuffled) + { + b_k_split_offset = k_id * splitted_k * NPerXdl; + } + else + { + if constexpr(std::is_same_v) + { + b_k_split_offset = k_id * KRead * kargs.stride_B; + } + else if constexpr(std::is_same_v) + { + b_k_split_offset = k_id * KRead; + } + } } index_t a_k_split_offset; @@ -573,15 +583,16 @@ struct MoeFlatmmKernel return DTesnorIsValid; } - template CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, const BDataType* b_flat_ptr, EDataType* e_ptr, [[maybe_unused]] const AccDataType* exp_weight_ptr, - const int expert_id, + [[maybe_unused]] const int expert_id, const KernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset) { @@ -742,13 +753,13 @@ struct MoeFlatmmKernel { index_t scale_k = BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK; + const auto scale_k_offset = + (splitk_batch_offset.b_k_split_offset / BGranularityK) * K_Pack; index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1); index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1); - using ScaleType = e8m0_t; - return make_naive_tensor_view( - reinterpret_cast(scale_n.ptr) + expert_id * kargs.N * scale_k, + scale_n.ptr + expert_id * kargs.N * scale_k + scale_k_offset, make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK), make_tuple(FlatScaleK, 1), number<8>{}, @@ -1386,11 +1397,16 @@ struct MoeFlatmmKernel if constexpr(!BMXFP4_Pipeline) lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_scale_m[idx] * epi_scale_n[idx]; - if constexpr(EnableBias) - lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx]; - if constexpr(!IsInputGemm) - lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx]; - else // for mlp1 gate-only + if(kind != + MoeFlatmmKind::kFFN_gemm1_split_k) // disable weight and bias for split-k + { + if constexpr(EnableBias) + lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx]; + if constexpr(!IsInputGemm) + lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx]; + } + if constexpr(kind == + MoeFlatmmKind::kFFN_gemm1_gate_only) // for mlp1 gate-only lds_tile[lds_stage].get_thread_buffer()[idx] = ActivationOp{}(lds_tile[lds_stage].get_thread_buffer()[idx]); });