mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Pack the last conversion in a loop
This commit is contained in:
@@ -401,8 +401,7 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
|
||||
{
|
||||
f4x32_t f4x32_array;
|
||||
f4x2_t fp4x2[16];
|
||||
} value{};
|
||||
value.f4x32_array = x;
|
||||
} value{x};
|
||||
float2_t op;
|
||||
float32_t ret;
|
||||
float f_scale = type_convert<float>(scale);
|
||||
|
||||
@@ -1002,110 +1002,16 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
|
||||
f4x32_t f4x32_array;
|
||||
f4x2_t fp4x2[16];
|
||||
} value{x};
|
||||
union
|
||||
{
|
||||
uint32_t bitwise;
|
||||
f4x2_t f4x2_array[4];
|
||||
} bitwise_value{};
|
||||
float2_t op;
|
||||
float32_t ret;
|
||||
float scale = 1.0f;
|
||||
// TODO: pack in a loop
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[0];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[0] = op[0];
|
||||
ret[1] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[1];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[2] = op[0];
|
||||
ret[3] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[2];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[4] = op[0];
|
||||
ret[5] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[3];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[6] = op[0];
|
||||
ret[7] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[4];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[8] = op[0];
|
||||
ret[9] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[5];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[10] = op[0];
|
||||
ret[11] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[6];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[12] = op[0];
|
||||
ret[13] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[7];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[14] = op[0];
|
||||
ret[15] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[8];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[16] = op[0];
|
||||
ret[17] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[9];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[18] = op[0];
|
||||
ret[19] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[10];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[20] = op[0];
|
||||
ret[21] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[11];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[22] = op[0];
|
||||
ret[23] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[12];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[24] = op[0];
|
||||
ret[25] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[13];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[26] = op[0];
|
||||
ret[27] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[14];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[28] = op[0];
|
||||
ret[29] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[15];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[30] = op[0];
|
||||
ret[31] = op[1];
|
||||
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], scale, 0);
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
ret[2 * idx] = op[1];
|
||||
ret[2 * idx + 1] = op[0];
|
||||
});
|
||||
|
||||
return ret;
|
||||
#else
|
||||
@@ -1120,106 +1026,18 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
|
||||
f4x2_t f4x2_array[16];
|
||||
f4x32_t f4x32_array;
|
||||
} f4_values{bit_cast<__uint128_t>(x)};
|
||||
// TODO: pack in a loop
|
||||
float_values.float_array[0] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[1] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[3] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[4] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[5] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[6] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[7] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
|
||||
float_values.float_array[0] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[1] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[3] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[4] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[5] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[6] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[7] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
|
||||
float_values.float_array[2 * idx] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[idx].template AsType<f4x2_pk_t>()[Number<0>{}].template unpack<>(
|
||||
Number<0>{}));
|
||||
|
||||
float_values.float_array[0] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[1] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[3] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[4] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[5] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[6] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[7] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
|
||||
float_values.float_array[0] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[1] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[3] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[4] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[5] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[6] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[7] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2 * idx + 1] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[idx].template AsType<f4x2_pk_t>()[Number<0>{}].template unpack<>(
|
||||
Number<1>{}));
|
||||
});
|
||||
|
||||
return float_values.float32_array;
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user