From d19cebbe346f8e47dbcc964c98f02e32169d5b3f Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer Date: Tue, 18 Feb 2025 20:03:43 +0000 Subject: [PATCH] Permute conversion args --- include/ck/utility/type_convert.hpp | 96 ++++++++++++++++++++++++----- 1 file changed, 80 insertions(+), 16 deletions(-) diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 907e85804d..9215db6f88 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -1025,55 +1025,119 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f float_values.floatx32_array = x; // TODO: pack in a loop tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[0], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[0][1], float_values.floatx2_array[0][0]}, + rng, + scale, + 0); f4_values.f4x2_array[0] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[1], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[1][1], float_values.floatx2_array[1][0]}, + rng, + scale, + 0); f4_values.f4x2_array[1] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[2], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[2][1], float_values.floatx2_array[2][0]}, + rng, + scale, + 0); f4_values.f4x2_array[2] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[3], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[3][1], float_values.floatx2_array[3][0]}, + rng, + scale, + 0); f4_values.f4x2_array[3] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[4], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[4][1], float_values.floatx2_array[4][0]}, + rng, + scale, + 0); f4_values.f4x2_array[4] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[5], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[5][1], float_values.floatx2_array[5][0]}, + rng, + scale, + 0); f4_values.f4x2_array[5] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[6], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[6][1], float_values.floatx2_array[6][0]}, + rng, + scale, + 0); f4_values.f4x2_array[6] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[7], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[7][1], float_values.floatx2_array[7][0]}, + rng, + scale, + 0); f4_values.f4x2_array[7] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[8], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[8][1], float_values.floatx2_array[8][0]}, + rng, + scale, + 0); f4_values.f4x2_array[8] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[9], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[9][1], float_values.floatx2_array[9][0]}, + rng, + scale, + 0); f4_values.f4x2_array[9] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[10], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[10][1], float_values.floatx2_array[10][0]}, + rng, + scale, + 0); f4_values.f4x2_array[10] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[11], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[11][1], float_values.floatx2_array[11][0]}, + rng, + scale, + 0); f4_values.f4x2_array[11] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[12], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[12][1], float_values.floatx2_array[12][0]}, + rng, + scale, + 0); f4_values.f4x2_array[12] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[13], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[13][1], float_values.floatx2_array[13][0]}, + rng, + scale, + 0); f4_values.f4x2_array[13] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[14], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[14][1], float_values.floatx2_array[14][0]}, + rng, + scale, + 0); f4_values.f4x2_array[14] = tmp_values.f4x2_array[0]; tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[15], rng, scale, 0); + tmp_values.bitwise, + float2_t{float_values.floatx2_array[15][1], float_values.floatx2_array[15][0]}, + rng, + scale, + 0); f4_values.f4x2_array[15] = tmp_values.f4x2_array[0]; return f4_values.f4x32_array;