From 551be3cb6750170db58c730bab64b138bb4954c6 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Sat, 1 Jun 2024 00:46:41 -0500 Subject: [PATCH] Post-merge fix of PR 1300 (#1313) * add f8 gemm with multiD for both row/col wise * change compute_type to fp8 * changed tuning parameters in the example * add rcr example * post-merge fix * fix * reduce init range [ROCm/composable_kernel commit: 6fb1f4e03fef8a80ae8b5f139b9d4750e2f1a972] --- .../gemm_multiply_multiply_xdl_fp16.cpp | 12 ++++++------ .../device_gemm_multiple_d_xdl_cshuffle_v3.hpp | 2 +- .../grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 14 +++++++------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp index b0e75a5594..c584ff20cf 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp @@ -59,7 +59,7 @@ struct MultiplyMultiply { const float x0_f = c * d0 * d1; - e = ck::type_convert(x0_f); + e = ck::type_convert(x0_f); } }; @@ -95,7 +95,7 @@ int main(int argc, char* argv[]) ck::index_t K = 4096; ck::index_t StrideA = K; - ck::index_t StrideB = N; + ck::index_t StrideB = K; ck::index_t StrideD = 0; ck::index_t StrideE = N; @@ -164,10 +164,10 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); break; default: a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp index 2275d83641..c2b5317dd9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -83,7 +83,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD -struct GridwiseGemm_xdl_cshuffle_v3 +struct GridwiseGemmMultiD_xdl_cshuffle_v3 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -690,8 +690,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( a_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(Number{}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(AK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -756,7 +756,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(AK1Number)), @@ -827,8 +827,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(Number{}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(BK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -890,7 +890,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(BK1Number)),