diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 7785357507..b1f6f0d15c 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -879,129 +879,23 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f __uint128_t bitwise; f4x2_t f4x2_array[16]; f4x32_t f4x32_array; - } f4_values{0}, tmp_values{0}; + } f4_values{0}; union { float2_t floatx2_array[16]; float32_t floatx32_array; } float_values{{0}}; float_values.floatx32_array = x; - // TODO: pack in a loop - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[0][1], float_values.floatx2_array[0][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[0] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[1][1], float_values.floatx2_array[1][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[1] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[2][1], float_values.floatx2_array[2][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[2] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[3][1], float_values.floatx2_array[3][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[3] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[4][1], float_values.floatx2_array[4][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[4] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[5][1], float_values.floatx2_array[5][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[5] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[6][1], float_values.floatx2_array[6][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[6] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[7][1], float_values.floatx2_array[7][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[7] = tmp_values.f4x2_array[0]; - - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[8][1], float_values.floatx2_array[8][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[8] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[9][1], float_values.floatx2_array[9][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[9] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[10][1], float_values.floatx2_array[10][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[10] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[11][1], float_values.floatx2_array[11][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[11] = tmp_values.f4x2_array[0]; - - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[12][1], float_values.floatx2_array[12][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[12] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[13][1], float_values.floatx2_array[13][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[13] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[14][1], float_values.floatx2_array[14][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[14] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, - float2_t{float_values.floatx2_array[15][1], float_values.floatx2_array[15][0]}, - rng, - scale, - 0); - f4_values.f4x2_array[15] = tmp_values.f4x2_array[0]; + ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { + // permute high bits and low bits to match the order of the original vector + f4_values.f4x2_array[idx] = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + f4_values.bitwise, + float2_t{float_values.floatx2_array[idx][1], float_values.floatx2_array[idx][0]}, + rng, + scale, + 0); + }); return f4_values.f4x32_array; #else @@ -1011,106 +905,14 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f f4x2_t f4x2_array[16]; f4x32_t f4x32_array; } f4_values{0}; - // TODO: pack in a loop - auto tmp = utils::sat_convert_to_type_sr(x[0] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[1] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[2] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[3] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[4] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[5] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[6] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[7] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[8] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[9] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[10] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[11] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[12] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[13] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[14] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[15] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; + f4_t tmp; - tmp = utils::sat_convert_to_type_sr(x[16] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[17] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[18] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[19] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[20] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[21] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[22] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[23] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - - tmp = utils::sat_convert_to_type_sr(x[24] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[25] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[26] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[27] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[28] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[29] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[30] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[31] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; + ck::static_for<0, 32, 1>{}([&](auto idx) { + tmp = utils::sat_convert_to_type_sr(x[static_cast(idx)] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + }); return f4_values.f4x32_array; #endif