diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp index b0e75a5594..c584ff20cf 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp @@ -59,7 +59,7 @@ struct MultiplyMultiply { const float x0_f = c * d0 * d1; - e = ck::type_convert(x0_f); + e = ck::type_convert(x0_f); } }; @@ -95,7 +95,7 @@ int main(int argc, char* argv[]) ck::index_t K = 4096; ck::index_t StrideA = K; - ck::index_t StrideB = N; + ck::index_t StrideB = K; ck::index_t StrideD = 0; ck::index_t StrideE = N; @@ -164,10 +164,10 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); break; default: a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp index 2275d83641..c2b5317dd9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -83,7 +83,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD -struct GridwiseGemm_xdl_cshuffle_v3 +struct GridwiseGemmMultiD_xdl_cshuffle_v3 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -690,8 +690,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( a_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(Number{}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(AK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -756,7 +756,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(AK1Number)), @@ -827,8 +827,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(Number{}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(BK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -890,7 +890,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(BK1Number)),