From 683194f86aae6ee53dcd1ef9d1b065db3ca59a99 Mon Sep 17 00:00:00 2001
From: Anthony Chang
Date: Fri, 18 Nov 2022 00:38:13 +0800
Subject: [PATCH] Work around develop validation failure (#513)

* workaround bf16 atten fwd issue on gfx908

* typo

[ROCm/composable_kernel commit: 892a8d769d95cf85ce5e5cab3432ddb000826588]
---
 include/ck/ck.hpp                                         | 7 +++++++
 ...gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp | 8 ++++++++
 .../test_batched_gemm_softmax_gemm_permute_bf16.cpp       | 2 +-
 3 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp
index ba4a414545..ddaef1db3b 100644
--- a/include/ck/ck.hpp
+++ b/include/ck/ck.hpp
@@ -154,6 +154,13 @@
 // tuning parameter
 #define CK_WORKAROUND_SWDEV_325164 0
 
+// workaround: a BF16 attention kernel for gfx908 is likely affected by a compiler issue
+#ifdef __gfx908__
+#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 1
+#else // __gfx90a__, ...
+#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 0
+#endif // __gfx908__
+
 namespace ck {
 
 enum struct InMemoryDataOperationEnum
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
index 0e512473d8..c8bc33afa3 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
@@ -874,6 +874,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 }
             } // end gemm1
 
+            // workaround compiler issue; see ck/ck.hpp
+            if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 &&
+                         is_same_v<FloatAB, bhalf_t> && MPerBlock == 256 && NPerBlock == 128 &&
+                         Gemm1NPerBlock == 128)
+            {
+                __builtin_amdgcn_sched_barrier(0);
+            }
+
             constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
                 gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
             constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp
index 43c10066b8..e55b37fa9a 100644
--- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp
+++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp
@@ -29,7 +29,7 @@ TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, KernelTypes)
 
 TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16) { this->Run(); }
 
-TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_FPBF_PadM)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadM)
 {
     this->lengths_ = std::vector<std::vector<int>>{
         {136, 128, 32, 128, 2, 3},