mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
[rocm-libraries] ROCm/rocm-libraries#4267 (commit 3c5d95e)
[CK_TILE] Extend support of mix precision microscaling BQuant (#4267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes Supported types combinations using BQuant=e8m0: - A=bf16 - B=bf16,bf8,fp4 Summary: - remove usage of `pk_fp4_raw_t`: consistent with other implementations and avoid taking into account of the packed size explicitly. In general, the raw type should not be used because CK Tile internally takes care of the PackedSize, so using the raw type adds unnecessary complexity to the implementation - handle microscaling by checking for `e8m0` type for BQuant (previous implementation was inconsistent) - add support for scaling instructions in `DequantPack8` - mx pipeline: - extend existing pipeline to support different B types - add support to scale and cast before writing to LDS or after reading from LDS (this can be defined in the `Problem` by the user) - block gemm: - mx pipeline is now using block gemm BQuant - block gemm BQuant can now load from LDS and apply scale and then call block gemm universal operator. This adds new functionalities and remove code duplication - warp gemm: - add case to support 128bit ds_read/write for both A and B when A=16bit and B=8bit - add examples and tests: note that some tests for bf16/fp4 already existed but were removed during previous tests refactoring. I added them again and other relevant tests for new types combinations ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [ ] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
3af1a0aafc
commit
4c626aeaa6
@@ -359,6 +359,260 @@ CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q)
|
||||
}
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16x8_t bf8x8_to_bf16x8_scale(const bf8x8_t& src, const float& scale)
|
||||
{
|
||||
bf16x8_t y;
|
||||
#if defined(__gfx950__)
|
||||
constexpr index_t USE_BOTTOM = 0;
|
||||
constexpr index_t USE_TOP = 1;
|
||||
|
||||
auto convert_quartet = [&](index_t src_offset, index_t dst_offset) {
|
||||
union
|
||||
{
|
||||
uint32_t packed;
|
||||
bf8_t elements[4];
|
||||
} input;
|
||||
|
||||
union
|
||||
{
|
||||
bf16x2_t vec;
|
||||
bf16_t elements[2];
|
||||
} output;
|
||||
|
||||
input.elements[0] = src[src_offset];
|
||||
input.elements[1] = src[src_offset + 1];
|
||||
input.elements[2] = src[src_offset + 2];
|
||||
input.elements[3] = src[src_offset + 3];
|
||||
|
||||
output.vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.packed, scale, USE_BOTTOM);
|
||||
y[dst_offset] = output.elements[0];
|
||||
y[dst_offset + 1] = output.elements[1];
|
||||
|
||||
output.vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.packed, scale, USE_TOP);
|
||||
y[dst_offset + 2] = output.elements[0];
|
||||
y[dst_offset + 3] = output.elements[1];
|
||||
};
|
||||
|
||||
convert_quartet(0, 0);
|
||||
convert_quartet(4, 4);
|
||||
#else
|
||||
static_for<0, 8, 1>{}([&](auto i) {
|
||||
y[i.value] = type_convert<bf16_t>(type_convert<float>(src[i.value]) * scale);
|
||||
});
|
||||
#endif
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16x8_t fp8x8_to_bf16x8_scale(const fp8x8_t& src, const float& scale)
|
||||
{
|
||||
bf16x8_t y;
|
||||
#if defined(__gfx950__)
|
||||
constexpr index_t USE_BOTTOM = 0;
|
||||
constexpr index_t USE_TOP = 1;
|
||||
|
||||
auto convert_quartet = [&](index_t src_offset, index_t dst_offset) {
|
||||
union
|
||||
{
|
||||
uint32_t packed;
|
||||
fp8_t elements[4];
|
||||
} input;
|
||||
|
||||
union
|
||||
{
|
||||
bf16x2_t vec;
|
||||
bf16_t elements[2];
|
||||
} output;
|
||||
|
||||
input.elements[0] = src[src_offset];
|
||||
input.elements[1] = src[src_offset + 1];
|
||||
input.elements[2] = src[src_offset + 2];
|
||||
input.elements[3] = src[src_offset + 3];
|
||||
|
||||
output.vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(input.packed, scale, USE_BOTTOM);
|
||||
y[dst_offset] = output.elements[0];
|
||||
y[dst_offset + 1] = output.elements[1];
|
||||
|
||||
output.vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(input.packed, scale, USE_TOP);
|
||||
y[dst_offset + 2] = output.elements[0];
|
||||
y[dst_offset + 3] = output.elements[1];
|
||||
};
|
||||
|
||||
convert_quartet(0, 0);
|
||||
convert_quartet(4, 4);
|
||||
#else
|
||||
static_for<0, 8, 1>{}([&](auto i) {
|
||||
y[i.value] = type_convert<bf16_t>(type_convert<float>(src[i.value]) * scale);
|
||||
});
|
||||
#endif
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp16x8_t fp8x8_to_fp16x8_scale(const fp8x8_t& src, const float& scale)
|
||||
{
|
||||
fp16x8_t y;
|
||||
#if defined(__gfx950__)
|
||||
constexpr index_t USE_BOTTOM = 0;
|
||||
constexpr index_t USE_TOP = 1;
|
||||
|
||||
auto convert_quartet = [&](index_t src_offset, index_t dst_offset) {
|
||||
union
|
||||
{
|
||||
uint32_t packed;
|
||||
fp8_t elements[4];
|
||||
} input;
|
||||
|
||||
union
|
||||
{
|
||||
fp16x2_t vec;
|
||||
fp16_t elements[2];
|
||||
} output;
|
||||
|
||||
input.elements[0] = src[src_offset];
|
||||
input.elements[1] = src[src_offset + 1];
|
||||
input.elements[2] = src[src_offset + 2];
|
||||
input.elements[3] = src[src_offset + 3];
|
||||
|
||||
output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(input.packed, scale, USE_BOTTOM);
|
||||
y[dst_offset] = output.elements[0];
|
||||
y[dst_offset + 1] = output.elements[1];
|
||||
|
||||
output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(input.packed, scale, USE_TOP);
|
||||
y[dst_offset + 2] = output.elements[0];
|
||||
y[dst_offset + 3] = output.elements[1];
|
||||
};
|
||||
|
||||
convert_quartet(0, 0);
|
||||
convert_quartet(4, 4);
|
||||
#else
|
||||
static_for<0, 8, 1>{}([&](auto i) {
|
||||
y[i.value] = type_convert<fp16_t>(type_convert<float>(src[i.value]) * scale);
|
||||
});
|
||||
#endif
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp16x8_t bf8x8_to_fp16x8_scale(const bf8x8_t& src, const float& scale)
|
||||
{
|
||||
fp16x8_t y;
|
||||
#if defined(__gfx950__)
|
||||
constexpr index_t USE_BOTTOM = 0;
|
||||
constexpr index_t USE_TOP = 1;
|
||||
|
||||
auto convert_quartet = [&](index_t src_offset, index_t dst_offset) {
|
||||
union
|
||||
{
|
||||
uint32_t packed;
|
||||
bf8_t elements[4];
|
||||
} input;
|
||||
|
||||
union
|
||||
{
|
||||
fp16x2_t vec;
|
||||
fp16_t elements[2];
|
||||
} output;
|
||||
|
||||
input.elements[0] = src[src_offset];
|
||||
input.elements[1] = src[src_offset + 1];
|
||||
input.elements[2] = src[src_offset + 2];
|
||||
input.elements[3] = src[src_offset + 3];
|
||||
|
||||
output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(input.packed, scale, USE_BOTTOM);
|
||||
y[dst_offset] = output.elements[0];
|
||||
y[dst_offset + 1] = output.elements[1];
|
||||
|
||||
output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(input.packed, scale, USE_TOP);
|
||||
y[dst_offset + 2] = output.elements[0];
|
||||
y[dst_offset + 3] = output.elements[1];
|
||||
};
|
||||
|
||||
convert_quartet(0, 0);
|
||||
convert_quartet(4, 4);
|
||||
#else
|
||||
static_for<0, 8, 1>{}([&](auto i) {
|
||||
y[i.value] = type_convert<fp16_t>(type_convert<float>(src[i.value]) * scale);
|
||||
});
|
||||
#endif
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE bf16x8_t fp4x4_to_bf16x8_scale(const pk_fp4x4_t& src, const float& scale)
|
||||
{
|
||||
bf16x8_t y;
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
uint32_t u32;
|
||||
pk_fp4x4_t pf4;
|
||||
} cvt;
|
||||
|
||||
constexpr index_t USE_BYTE_0 = 0;
|
||||
constexpr index_t USE_BYTE_1 = 1;
|
||||
constexpr index_t USE_BYTE_2 = 2;
|
||||
constexpr index_t USE_BYTE_3 = 3;
|
||||
|
||||
cvt.pf4 = src;
|
||||
bf16x2_t y0 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_0);
|
||||
bf16x2_t y1 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_1);
|
||||
bf16x2_t y2 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_2);
|
||||
bf16x2_t y3 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_3);
|
||||
|
||||
y[0] = y0[0];
|
||||
y[1] = y0[1];
|
||||
y[2] = y1[0];
|
||||
y[3] = y1[1];
|
||||
y[4] = y2[0];
|
||||
y[5] = y2[1];
|
||||
y[6] = y3[0];
|
||||
y[7] = y3[1];
|
||||
#else
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
auto yi = pk_fp4_to_bf16x2(src[i.value], scale);
|
||||
y[2 * i.value] = yi[0];
|
||||
y[2 * i.value + 1] = yi[1];
|
||||
});
|
||||
#endif
|
||||
return y;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE fp16x8_t fp4x4_to_fp16x8_scale(const pk_fp4x4_t& src, const float& scale)
|
||||
{
|
||||
fp16x8_t y;
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
uint32_t u32;
|
||||
pk_fp4x4_t pf4;
|
||||
} cvt;
|
||||
|
||||
constexpr index_t USE_BYTE_0 = 0;
|
||||
constexpr index_t USE_BYTE_1 = 1;
|
||||
constexpr index_t USE_BYTE_2 = 2;
|
||||
constexpr index_t USE_BYTE_3 = 3;
|
||||
|
||||
cvt.pf4 = src;
|
||||
fp16x2_t y0 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_0);
|
||||
fp16x2_t y1 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_1);
|
||||
fp16x2_t y2 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_2);
|
||||
fp16x2_t y3 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_3);
|
||||
|
||||
y[0] = y0[0];
|
||||
y[1] = y0[1];
|
||||
y[2] = y1[0];
|
||||
y[3] = y1[1];
|
||||
y[4] = y2[0];
|
||||
y[5] = y2[1];
|
||||
y[6] = y3[0];
|
||||
y[7] = y3[1];
|
||||
#else
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
auto yi = pk_fp4_to_fp16x2(src[i.value], scale);
|
||||
y[2 * i.value] = yi[0];
|
||||
y[2 * i.value + 1] = yi[1];
|
||||
});
|
||||
#endif
|
||||
return y;
|
||||
}
|
||||
|
||||
struct PassThroughPack8
|
||||
{
|
||||
static constexpr const char* name = "PassThroughPack8";
|
||||
@@ -437,6 +691,50 @@ struct DequantPack8
|
||||
y.hi = i4_to_half4_scale(bit_cast<int>(x) >> 8, z);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(bf16x8_t& y, const pk_fp4x4_t& x, const float& z) const
|
||||
{
|
||||
y = fp4x4_to_bf16x8_scale(x, z);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(fp16x8_t& y, const pk_fp4x4_t& x, const float& z) const
|
||||
{
|
||||
y = fp4x4_to_fp16x8_scale(x, z);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(bf16x8_t& y, const bf8x8_t& x, const float& z) const
|
||||
{
|
||||
y = bf8x8_to_bf16x8_scale(x, z);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(bf16x8_t& y, const fp8x8_t& x, const float& z) const
|
||||
{
|
||||
y = fp8x8_to_bf16x8_scale(x, z);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(fp16x8_t& y, const fp8x8_t& x, const float& z) const
|
||||
{
|
||||
y = fp8x8_to_fp16x8_scale(x, z);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(fp16x8_t& y, const bf8x8_t& x, const float& z) const
|
||||
{
|
||||
y = bf8x8_to_fp16x8_scale(x, z);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
operator()(bf16x8_t& y, const bf16x8_t& x, const float& z) const
|
||||
{
|
||||
static_for<0, 8, 1>{}([&](auto i) {
|
||||
y[i.value] = type_convert<bf16_t>(type_convert<float>(x[i.value]) * z);
|
||||
});
|
||||
}
|
||||
|
||||
constexpr const static bool is_pack8_invocable = true;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user