mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
fixed fp8 issues (#894)
* fixed fp8 init; and reference gemm
* Update host_tensor_generator.hpp
* fixed convert
* fixed reference gemm
* fixed comments
* fixed comments
* fixed ci
* fixed computeType
---------
Co-authored-by: Jing Zhang <jizha@amd.com>
[ROCm/composable_kernel commit: a66d14edf2]
This commit is contained in:
@@ -14,18 +14,22 @@ using ComputeDataType = float;
|
||||
|
||||
struct YElementOp
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
{
|
||||
static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value ||
|
||||
ck::is_same<T, ck::half_t>::value,
|
||||
static_assert(ck::is_same<X, float>::value || ck::is_same<X, double>::value ||
|
||||
ck::is_same<X, ck::half_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
T a;
|
||||
static_assert(ck::is_same<Y, float>::value || ck::is_same<Y, double>::value ||
|
||||
ck::is_same<Y, ck::half_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
X a;
|
||||
|
||||
ck::tensor_operation::element_wise::Sigmoid{}(a, x);
|
||||
|
||||
y = x * a;
|
||||
y = ck::type_convert<Y>(x * a);
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -144,7 +144,8 @@ template <typename ALayout,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
typename ComputeDataType = EDataType>
|
||||
struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
@@ -243,11 +244,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
|
||||
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
|
||||
|
||||
using ComputeDataType = EDataType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
|
||||
@@ -27,6 +27,12 @@ struct PassThrough
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
@@ -81,6 +87,12 @@ struct PassThrough
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
__host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
|
||||
@@ -416,14 +428,19 @@ struct Swish
|
||||
{
|
||||
Swish(float beta = 1.0f) : beta_(beta) {}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
{
|
||||
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
|
||||
is_same<T, ck::half_t>::value,
|
||||
static_assert(is_same<X, float>::value || is_same<X, double>::value ||
|
||||
is_same<X, ck::half_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x));
|
||||
static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
|
||||
is_same<Y, ck::half_t>::value,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
float bx = -beta_ * type_convert<float>(x);
|
||||
y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
|
||||
};
|
||||
|
||||
float beta_ = 1.0f;
|
||||
|
||||
@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
SrcData v;
|
||||
DstData v;
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v, src_buf[Number<src_offset>{}]);
|
||||
|
||||
// apply type convert
|
||||
dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
|
||||
dst_vector.template AsType<DstData>()(i) = v;
|
||||
});
|
||||
|
||||
const bool is_dst_valid =
|
||||
@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
SrcData v;
|
||||
DstData v;
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v, src_buf[Number<src_offset>{}]);
|
||||
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
|
||||
dst_buf(Number<dst_offset>{}) = v;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -20,7 +20,8 @@ template <typename ADataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
typename CElementwiseOperation,
|
||||
typename ComputType = ADataType>
|
||||
struct ReferenceGemm : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
@@ -64,8 +65,8 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a;
|
||||
BDataType v_b;
|
||||
ComputType v_a;
|
||||
ComputType v_b;
|
||||
|
||||
// use PassThrough instead of ConvertBF16RTN for reference calculation
|
||||
if constexpr(is_same_v<AElementwiseOperation,
|
||||
|
||||
@@ -83,8 +83,8 @@ bool profile_gemm_multiply_add_impl(int do_verification,
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-1, 1});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 0.2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.1, 0.1});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user