This commit is contained in:
mtgu0705
2025-09-15 20:50:26 -05:00
parent 3893c06540
commit a333206929
2 changed files with 28 additions and 17 deletions

View File

@@ -101,8 +101,8 @@ int run_mx_flatmm_with_layouts(int argc,
ck_tile::FillUniformDistribution<ScaleDataType>{-2.f, 2.f}(scale_b);
}
#if 0
#if 1
#if 0
printf("printf a_host: \n");
for(int m = 0; m < M; m++)
{
@@ -141,9 +141,9 @@ int run_mx_flatmm_with_layouts(int argc,
printf("\n");
printf("printf scale_b: \n");
for(int n = 0; n < N / DequantGranularityN; n++)
for(int n = 0; n < N / ScaleGranularityN; n++)
{
for(int k = 0; k < K / DequantGranularityK; k++)
for(int k = 0; k < K / ScaleGranularityK; k++)
{
printf("%.2f ", ck_tile::type_convert<float>(scale_b(k, n)));
}
@@ -172,13 +172,24 @@ int run_mx_flatmm_with_layouts(int argc,
}
printf("\n");
}
printf("\n");
printf("printf scale_a: \n");
for(int m = 0; m < M / ScaleGranularityM; m++)
{
for(int k = 0; k < K / ScaleGranularityK;)
{
printf("0x%08x ", *(reinterpret_cast<uint32_t*>(&scale_a(m, k).data)));
k += 4;
}
printf("\n");
}
printf("\n");
printf("printf scale_b: \n");
for(int n = 0; n < N / DequantGranularityN; n++)
for(int n = 0; n < N / ScaleGranularityN; n++)
{
for(int k = 0; k < K / DequantGranularityK;)
for(int k = 0; k < K / ScaleGranularityK;)
{
printf("0x%08x ", *(reinterpret_cast<uint32_t*>(&scale_b(k, n).data)));
k += 4;

View File

@@ -226,7 +226,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_mxfp4(float x, float scale)
return convert_to_type<pk_fp4_t>(x, scale);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale = 1.0f)
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
@@ -235,7 +235,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float sca
return pk_fp4_t::pack(res, res);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale = 1.0f)
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
@@ -244,7 +244,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float sca
return pk_fp4_t::pack(res, res);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale = 1.0f)
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
@@ -253,7 +253,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float sca
return pk_fp4_t::pack(res, res);
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale = 1.0f)
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
@@ -261,7 +261,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float
return pk_fp4_t::pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale = 1.0f)
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
@@ -269,7 +269,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float
return pk_fp4_t::pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale));
#endif
}
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale = 1.0f)
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
{
#if CK_TILE_FP4_CVT_DEVICE
return impl::_to_f4(x, scale);
@@ -278,27 +278,27 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float
#endif
}
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale = 1.0f)
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale)
{
return x.to_fp32x2(scale);
}
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale = 1.0f)
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale)
{
return x.to_fp16x2(scale);
}
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale = 1.0f)
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale)
{
return x.to_bf16x2(scale);
}
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale = 1.0f)
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale)
{
return x.to_float(scale);
}
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale = 1.0f)
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale)
{
return x.to_fp16(scale);
}
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale = 1.0f)
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale)
{
return x.to_bf16(scale);
}