mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
fix
This commit is contained in:
@@ -101,8 +101,8 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
#if 0
|
||||
printf("printf a_host: \n");
|
||||
for(int m = 0; m < M; m++)
|
||||
{
|
||||
@@ -141,9 +141,9 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
printf("\n");
|
||||
|
||||
printf("printf scale_b: \n");
|
||||
for(int n = 0; n < N / DequantGranularityN; n++)
|
||||
for(int n = 0; n < N / ScaleGranularityN; n++)
|
||||
{
|
||||
for(int k = 0; k < K / DequantGranularityK; k++)
|
||||
for(int k = 0; k < K / ScaleGranularityK; k++)
|
||||
{
|
||||
printf("%.2f ", ck_tile::type_convert<float>(scale_b(k, n)));
|
||||
}
|
||||
@@ -172,13 +172,24 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
printf("printf scale_a: \n");
|
||||
for(int m = 0; m < M / ScaleGranularityM; m++)
|
||||
{
|
||||
for(int k = 0; k < K / ScaleGranularityK;)
|
||||
{
|
||||
printf("0x%08x ", *(reinterpret_cast<uint32_t*>(&scale_a(m, k).data)));
|
||||
k += 4;
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
printf("printf scale_b: \n");
|
||||
for(int n = 0; n < N / DequantGranularityN; n++)
|
||||
for(int n = 0; n < N / ScaleGranularityN; n++)
|
||||
{
|
||||
for(int k = 0; k < K / DequantGranularityK;)
|
||||
for(int k = 0; k < K / ScaleGranularityK;)
|
||||
{
|
||||
printf("0x%08x ", *(reinterpret_cast<uint32_t*>(&scale_b(k, n).data)));
|
||||
k += 4;
|
||||
|
||||
@@ -226,7 +226,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_mxfp4(float x, float scale)
|
||||
return convert_to_type<pk_fp4_t>(x, scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale = 1.0f)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -235,7 +235,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float sca
|
||||
return pk_fp4_t::pack(res, res);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale = 1.0f)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -244,7 +244,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float sca
|
||||
return pk_fp4_t::pack(res, res);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale = 1.0f)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -253,7 +253,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float sca
|
||||
return pk_fp4_t::pack(res, res);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale = 1.0f)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -261,7 +261,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float
|
||||
return pk_fp4_t::pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale = 1.0f)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -269,7 +269,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float
|
||||
return pk_fp4_t::pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale = 1.0f)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -278,27 +278,27 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale = 1.0f)
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_fp32x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale = 1.0f)
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_fp16x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale = 1.0f)
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_bf16x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale = 1.0f)
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_float(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale = 1.0f)
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_fp16(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale = 1.0f)
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_bf16(scale);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user