From fba3d780f242eb3080799c7bfe34d1d336fc2ebe Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Thu, 13 Feb 2025 15:48:15 +0800 Subject: [PATCH] fix bug, function now passes. --- .../impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 10 ++++++---- .../gpu/thread/threadwise_tensor_slice_transfer.hpp | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index cc2a55bea8..7ce78cd7c6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -339,19 +339,21 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle 1) { const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + minimum_occupancy, + TailNumber::Odd>; Run(kernel); } else { const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + minimum_occupancy, + TailNumber::Odd>; Run(kernel); } } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 6eebed7319..37cbeeb3e9 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -377,14 +377,14 @@ struct ThreadwiseTensorSliceTransfer_v2 if constexpr(InvalidElementAsNaN) { - dst_buf(Number{}) = + dst_buf(Number{}) = is_src_valid ? type_convert(src_vector.template AsType()[i]) : NumericLimits::QuietNaN(); } else { - dst_buf(Number{}) = + dst_buf(Number{}) = type_convert(src_vector.template AsType()[i]); } }); @@ -1619,7 +1619,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic constexpr index_t src_offset = src_desc.CalculateOffset( src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - src_tmp_vector.template AsType()(i) = src_buf[Number{}]; + src_tmp_vector.template AsType()(i) = src_buf[Number{}]; }); // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to