diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp index ad74ee847e..ce25585341 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp @@ -140,9 +140,9 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad "Direct load transfer does not support datatypes conversion. Source and " "destination data types must be the same."); - static_assert( - DstVectorDim == nDim - 1, - "Direct load transfer requires the destination vector dimension to be the last one."); + // static_assert( + // DstVectorDim == nDim - 1, + // "Direct load transfer requires the destination vector dimension to be the last one."); static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim, "When loading more than one element per thread at once, the contiguous " diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index d1015ee504..2f1b73caaf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -933,12 +933,18 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 #endif // B matrix in LDS memory, dst of blockwise copy - if constexpr(DirectLoad) + if constexpr(DirectLoad && BBlockTransferSrcVectorDim == 2) { return make_naive_tensor_descriptor( make_tuple(BK0Number, Number{}, BK1Number), make_tuple(BK1Number, Number{}, I1)); } + else if constexpr(DirectLoad && BBlockTransferSrcVectorDim == 1) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{}, I1, Number{})); + } else if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { // bank conflict when writting the data into LDS, but don't worry, we have whole entire @@ -1633,7 +1639,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 decltype(b_block_desc_bk0_n_bk1), BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, - 2, + BBlockTransferSrcVectorDim, // enforcer earlier BBlockTransferSrcScalarPerVector>( b_grid_desc_bk0_n_bk1, make_multi_index(num_bk0_per_block * k_idx, n_block_data_idx_on_grid, 0),