mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Pack one more conversion in a loop
This commit is contained in:
@@ -879,129 +879,23 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f
|
||||
__uint128_t bitwise;
|
||||
f4x2_t f4x2_array[16];
|
||||
f4x32_t f4x32_array;
|
||||
} f4_values{0}, tmp_values{0};
|
||||
} f4_values{0};
|
||||
union
|
||||
{
|
||||
float2_t floatx2_array[16];
|
||||
float32_t floatx32_array;
|
||||
} float_values{{0}};
|
||||
float_values.floatx32_array = x;
|
||||
// TODO: pack in a loop
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[0][1], float_values.floatx2_array[0][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[1][1], float_values.floatx2_array[1][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[2][1], float_values.floatx2_array[2][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[2] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[3][1], float_values.floatx2_array[3][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[3] = tmp_values.f4x2_array[0];
|
||||
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[4][1], float_values.floatx2_array[4][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[4] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[5][1], float_values.floatx2_array[5][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[5] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[6][1], float_values.floatx2_array[6][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[6] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[7][1], float_values.floatx2_array[7][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[7] = tmp_values.f4x2_array[0];
|
||||
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[8][1], float_values.floatx2_array[8][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[8] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[9][1], float_values.floatx2_array[9][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[9] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[10][1], float_values.floatx2_array[10][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[10] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[11][1], float_values.floatx2_array[11][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[11] = tmp_values.f4x2_array[0];
|
||||
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[12][1], float_values.floatx2_array[12][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[12] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[13][1], float_values.floatx2_array[13][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[13] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[14][1], float_values.floatx2_array[14][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[14] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[15][1], float_values.floatx2_array[15][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
f4_values.f4x2_array[15] = tmp_values.f4x2_array[0];
|
||||
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
f4_values.f4x2_array[idx] = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
f4_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[idx][1], float_values.floatx2_array[idx][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
});
|
||||
|
||||
return f4_values.f4x32_array;
|
||||
#else
|
||||
@@ -1011,106 +905,14 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f
|
||||
f4x2_t f4x2_array[16];
|
||||
f4x32_t f4x32_array;
|
||||
} f4_values{0};
|
||||
// TODO: pack in a loop
|
||||
auto tmp = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[2] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[3] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[4] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[5] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[6] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[7] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[8] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[9] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[10] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[11] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[12] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[13] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[14] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[15] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
f4_t tmp;
|
||||
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[16] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[17] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[18] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[19] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[20] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[21] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[22] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[23] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[24] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[25] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[26] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[27] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[28] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[29] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[30] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[31] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
ck::static_for<0, 32, 1>{}([&](auto idx) {
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[static_cast<int>(idx)] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
});
|
||||
|
||||
return f4_values.f4x32_array;
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user