diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index ba4a414545..ddaef1db3b 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -154,6 +154,13 @@ // tuning parameter #define CK_WORKAROUND_SWDEV_325164 0 +// workaround: a BF16 attention kernel for gfx908 is likely affected by a compiler issue +#ifdef __gfx908__ +#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 1 +#else // __gfx90a__, ... +#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 0 +#endif // __gfx908__ + namespace ck { enum struct InMemoryDataOperationEnum diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index 0e512473d8..c8bc33afa3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -874,6 +874,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle } } // end gemm1 + // workaround compiler issue; see ck/ck.hpp + if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 && + is_same_v && MPerBlock == 256 && NPerBlock == 128 && + Gemm1NPerBlock == 128) + { + __builtin_amdgcn_sched_barrier(0); + } + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp index 43c10066b8..e55b37fa9a 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp @@ -29,7 +29,7 @@ TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, KernelTypes) TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16) { this->Run(); } -TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_FPBF_PadM) +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadM) { this->lengths_ = std::vector>{ {136, 128, 32, 128, 2, 3},