Implement grouped gemm tile loop for RDNA4 (#3304)

* feat: grouped gemm tile loop support for RDNA4

* fix: removed extra parameter from grouped gemm example instance

* fix: FP8 check incorrectly enabling FP8 on RDNA3
This commit is contained in:
Erwin Terpstra
2026-01-13 07:14:23 +01:00
committed by GitHub
parent 141f77aa12
commit eb041079a3
44 changed files with 3067 additions and 1223 deletions

View File

@@ -8,6 +8,7 @@
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/quantization_operation.hpp"
+#include "ck/utility/type_convert.hpp"
namespace ck {
namespace tensor_operation {
@@ -236,8 +237,9 @@ struct MultiplyAdd
const half_t& d0,
const half_t& d1) const
{
-const half_t y = type_convert<half_t>(c) * d0 + d1;
-e = y;
+const half_t y =
+    type_convert<half_t>(c * type_convert<float>(d0) + type_convert<float>(d1));
+e = y;
}
template <>
__host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
@@ -245,8 +247,9 @@ struct MultiplyAdd
const bhalf_t& d0,
const bhalf_t& d1) const
{
-const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
-e = y;
+const bhalf_t y =
+    type_convert<bhalf_t>(c * type_convert<float>(d0) + type_convert<float>(d1));
+e = y;
}
template <>
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,