From 5827d0d892980f58921dfbf422c35aae0dda3c10 Mon Sep 17 00:00:00 2001 From: music-dino <111048524+music-dino@users.noreply.github.com> Date: Tue, 20 Jan 2026 16:25:30 +0100 Subject: [PATCH] Batched gemm softmax gemm descriptor fix (#3564) * Add rocm to prefix path for codegen * Fix issue with c0_matrix_mask construction [ROCm/composable_kernel commit: 6300ad3c62298dc6fdddfcf19ecd074f7f08fa96] --- .../impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 35b2f54f58..e3a990bcb1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -1059,7 +1059,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle c_grid_desc_m_n)}, has_main_k_block_loop{GridwiseGemm64::CalculateHasMainKBlockLoop( a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, - c0_matrix_mask{c.GetLength(I1)}, + c0_matrix_mask{b.GetLength(I0)}, a_element_op{a_element_op_}, b_element_op{b_element_op_}, b1_element_op{b1_element_op_},