diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp old mode 100755 new mode 100644 diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 973006196b..25fab6bde0 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -55,17 +55,17 @@ struct GemmConfig #endif #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) // Compute friendly for Intrawave scheduler - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64; + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; static constexpr ck_tile::index_t M_Warp = 2; static constexpr ck_tile::index_t N_Warp = 2; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; static constexpr bool DoubleSmemBuffer = false; #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index b4ea5d22c0..79ed9ce76b 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -402,7 +402,6 @@ int run_gemm_example_with_layouts(int argc, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); - std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 2c29814b73..bd7a0566a2 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -24,9 +24,14 @@ using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl, 2>>; +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; +#else using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl, @@ -49,10 +54,16 @@ using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>; +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl>>; // bf16 - using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; @@ -87,9 +97,14 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl, 2>>; +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; +#else using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl, @@ -113,10 +128,16 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>; +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl +struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = bf16_t; + using BDataType = bf16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 32; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32_bf16", Ctrl) + else + { +#if defined(__gfx950__) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + return bit_cast( + __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; // FP16 template struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 @@ -188,6 +251,69 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 } }; +template +struct WarpGemmAttributeMfmaImplF16F16F32M16N16K32 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 32; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32f16", Ctrl) + else + { +#if defined(__gfx950__) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + return bit_cast( + __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + template struct WarpGemmAttributeMfmaImplF16F16F32M4N64K4 {