diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 96678d275a..39212d2904 100755 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -38,6 +38,12 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_streamk_v3) add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3) +set(GEMM_OPTIONS) +list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-16") +list(APPEND GEMM_OPTIONS -v --save-temps -Wno-gnu-line-marker) +target_compile_options(example_gemm_xdl_bf16_v3 PRIVATE ${GEMM_OPTIONS}) +target_compile_options(example_gemm_xdl_fp8_v3 PRIVATE ${GEMM_OPTIONS}) + list(APPEND gpu_list gfx942 gfx950) set(target 0) diff --git a/example/01_gemm/gemm_xdl_fp8_v3.cpp b/example/01_gemm/gemm_xdl_fp8_v3.cpp index da891267b2..55a6c60273 100644 --- a/example/01_gemm/gemm_xdl_fp8_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_v3.cpp @@ -28,10 +28,10 @@ using DeviceGemmV2Instance = ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 256, - 224, 256, + 256, 256, 128, 16, 16, 16, 16, - 7, 8, + 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp index 171a232c0f..fe66d320e0 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -178,7 +178,7 @@ struct BlockwiseGemmXdlops_pipeline_v3 constexpr auto GetMfma() { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else return MfmaInstr::mfma_f32_32x32x16f8f8; +#endif } template <> @@ -1132,7 +1136,11 @@ struct MfmaSelector template <> constexpr auto GetMfma() { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else return MfmaInstr::mfma_f32_16x16x32f8f8; +#endif } template <> diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 9c40d923d3..466116242f 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -954,11 +954,11 @@ struct vector_type()>> StaticallyIndexedArray d32x4_; StaticallyIndexedArray d64x2_; StaticallyIndexedArray d128x1_; - } data_; + } data_ = {d128_t{0}}; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __attribute__((host)) __attribute__((device)) constexpr vector_type() {} - __host__ __device__ constexpr vector_type(type v) : data_{v} {} + __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } template __host__ __device__ constexpr const auto& AsType() const @@ -1082,11 +1082,11 @@ struct vector_type()>> StaticallyIndexedArray d64x4_; StaticallyIndexedArray d128x2_; StaticallyIndexedArray d256x1_; - } data_; + } data_ = {d256_t{0}}; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __attribute__((host)) __attribute__((device)) constexpr vector_type() {} - __host__ __device__ constexpr vector_type(type v) : data_{v} {} + __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } template __host__ __device__ constexpr const auto& AsType() const