mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Implement grouped gemm tile loop for RDNA4 (#3304)
* feat: grouped gemm tile loop support for RDNA4 * fix: removed extra parameter from grouped gemm example instance * fix: FP8 check incorrectly enabling FP8 on RDNA3
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/quantization_operation.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -236,8 +237,9 @@ struct MultiplyAdd
|
||||
const half_t& d0,
|
||||
const half_t& d1) const
|
||||
{
|
||||
const half_t y = type_convert<half_t>(c) * d0 + d1;
|
||||
e = y;
|
||||
const half_t y =
|
||||
type_convert<half_t>(c * type_convert<float>(d0) + type_convert<float>(d1));
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
|
||||
@@ -245,8 +247,9 @@ struct MultiplyAdd
|
||||
const bhalf_t& d0,
|
||||
const bhalf_t& d1) const
|
||||
{
|
||||
const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
|
||||
e = y;
|
||||
const bhalf_t y =
|
||||
type_convert<bhalf_t>(c * type_convert<float>(d0) + type_convert<float>(d1));
|
||||
e = y;
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
|
||||
|
||||
Reference in New Issue
Block a user