mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4267 (commit 3c5d95e)
[CK_TILE] Extend support of mixed-precision microscaling BQuant (#4267) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes Supported type combinations using BQuant=e8m0: - A=bf16 - B=bf16,bf8,fp4 Summary: - remove usage of `pk_fp4_raw_t`: consistent with other implementations and avoids explicitly taking the packed size into account. In general, the raw type should not be used because CK Tile internally takes care of the PackedSize, so using the raw type adds unnecessary complexity to the implementation - handle microscaling by checking for the `e8m0` type for BQuant (the previous implementation was inconsistent) - add support for scaling instructions in `DequantPack8` - mx pipeline: - extend the existing pipeline to support different B types - add support to scale and cast before writing to LDS or after reading from LDS (this can be defined in the `Problem` by the user) - block gemm: - the mx pipeline now uses block gemm BQuant - block gemm BQuant can now load from LDS, apply the scale, and then call the block gemm universal operator. This adds new functionality and removes code duplication - warp gemm: - add a case to support 128-bit ds_read/write for both A and B when A=16bit and B=8bit - add examples and tests: note that some tests for bf16/fp4 already existed but were removed during a previous test refactoring. I added them back, along with other relevant tests for the new type combinations ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [ ] I have added the test to the REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. 
- [ ] I have added inline documentation which helps the maintainers understand the motivation - [ ] I have removed stale documentation which is no longer relevant after this pull request - [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [ ] I have run `clang-format` on all changed files - [ ] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
This commit is contained in:
committed by
assistant-librarian[bot]
parent
3af1a0aafc
commit
4c626aeaa6
@@ -2865,6 +2865,12 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_);
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::pk_fp4_t>)
|
||||
{
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
|
||||
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_);
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr4_b64_v2i32(lds_ptr));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "not implemented");
|
||||
|
||||
@@ -50,60 +50,61 @@ constexpr bool is_sequence_suffix_v = is_sequence_suffix<Suffix, Sequence>::valu
|
||||
template <typename DataType>
|
||||
struct DefaultTranspose
|
||||
{
|
||||
template <index_t LaneGroupSize>
|
||||
struct Quad16
|
||||
template <index_t LaneGroupSize, index_t NumBitType>
|
||||
struct Quad
|
||||
{
|
||||
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
|
||||
"LaneGroupSize must be 64, 32, or 16");
|
||||
using InputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<4>, sequence<LaneGroupSize / 16, 4, 4>>,
|
||||
tuple<sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>;
|
||||
|
||||
using OutputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<LaneGroupSize>, sequence<4>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
// The tile is defined by the LaneGroupSize, which defines the number of lanes in the M/N
|
||||
// dimensions for the MMA instruction defined by warp gemm.
|
||||
// The LaneGroupSize is subdivided into groups of 16 (finer granularity of MMA
|
||||
// instructions), we define these as major subtiles. Each of these major subtile is divided
|
||||
// into minor subtiles which group the lanes exchanging data during the transpose Example
|
||||
// LaneGroupSize = 16, 16 bit type:
|
||||
// - There is 1 group of 16 lanes (1 major subtile)
|
||||
// - Each major subtile is divided into 4 minor subtiles of (4x4) -> 4 lanes transpose
|
||||
// the minor subtile and each lane holds 4 elements
|
||||
|
||||
// all load transpose instructions use 64 bit right now
|
||||
static constexpr index_t InstructionBits = 64;
|
||||
// Subtile major dimension is fixed
|
||||
static constexpr index_t SubtileMajorDimension = 16;
|
||||
// Number of subtile major
|
||||
static constexpr index_t NumSubtilesMajor = LaneGroupSize / 16;
|
||||
// number of elements loaded by each lane with single instruction, but also number
|
||||
// of consecutive lanes in a subtile. Subtile is squared (NLanes x NElementsPerLane)
|
||||
static constexpr index_t SubtileMinorDimension = InstructionBits / NumBitType;
|
||||
// Number of subtiles minor inside each subtile major
|
||||
static constexpr index_t NumSubtilesMinor = 16 / SubtileMinorDimension;
|
||||
|
||||
using InputEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<SubtileMinorDimension>,
|
||||
sequence<NumSubtilesMajor, NumSubtilesMinor, SubtileMinorDimension>>,
|
||||
tuple<sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>;
|
||||
|
||||
using OutputEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<LaneGroupSize>, sequence<SubtileMinorDimension>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
|
||||
template <index_t LaneGroupSize>
|
||||
struct Quad8
|
||||
{
|
||||
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
|
||||
"LaneGroupSize must be 64, 32, or 16");
|
||||
using InputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<8>, sequence<LaneGroupSize / 16, 2, 8>>,
|
||||
tuple<sequence<2, 1, 2>>,
|
||||
tuple<sequence<0, 0, 1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>;
|
||||
|
||||
using OutputEncoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<LaneGroupSize>, sequence<8>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>;
|
||||
};
|
||||
static constexpr index_t PackedSize = numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
static constexpr index_t NumBitsDataType = (sizeof(DataType) * 8) / PackedSize;
|
||||
|
||||
// Select based on data size
|
||||
template <index_t LaneGroupSize>
|
||||
using QuadInputEncoding = std::conditional_t<sizeof(DataType) == 2,
|
||||
typename Quad16<LaneGroupSize>::InputEncoding,
|
||||
typename Quad8<LaneGroupSize>::InputEncoding>;
|
||||
using QuadInputEncoding = typename Quad<LaneGroupSize, NumBitsDataType>::InputEncoding;
|
||||
|
||||
template <index_t LaneGroupSize>
|
||||
using QuadOutputEncoding = std::conditional_t<sizeof(DataType) == 2,
|
||||
typename Quad16<LaneGroupSize>::OutputEncoding,
|
||||
typename Quad8<LaneGroupSize>::OutputEncoding>;
|
||||
using QuadOutputEncoding = typename Quad<LaneGroupSize, NumBitsDataType>::OutputEncoding;
|
||||
|
||||
// Always swap last two dimensions
|
||||
static constexpr auto transpose_dims = sequence<1, 0>{};
|
||||
|
||||
@@ -78,7 +78,7 @@ struct static_distributed_tensor
|
||||
constexpr auto sliced_thread_tensor_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
|
||||
|
||||
thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
|
||||
thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size() / PackedSize>
|
||||
sliced_thread_data;
|
||||
|
||||
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
|
||||
|
||||
@@ -287,8 +287,8 @@ struct tensor_view
|
||||
get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const
|
||||
{
|
||||
return buf_.template transpose_get<X>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
|
||||
}
|
||||
|
||||
@@ -303,7 +303,8 @@ struct tensor_view
|
||||
bool is_valid_element // flag
|
||||
) const
|
||||
{
|
||||
return buf_.template transpose_get<X>(coord.get_offset(), linear_offset, is_valid_element);
|
||||
return buf_.template transpose_get<X>(
|
||||
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element);
|
||||
}
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
|
||||
@@ -736,7 +736,7 @@ struct tile_window_with_static_distribution
|
||||
.template get_transpose_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, offset);
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto orig_idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
@@ -747,10 +747,12 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto grouped_idx_ys = group_func(orig_idx_ys);
|
||||
|
||||
constexpr index_t linear_distributed_index =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
|
||||
vec_value.template get_as<typename Base::DataType>()[j];
|
||||
vec_value
|
||||
.template get_as<typename Base::DataType>()[j / Traits::PackedSize];
|
||||
});
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
|
||||
Reference in New Issue
Block a user