[rocm-libraries] ROCm/rocm-libraries#4267 (commit 3c5d95e)

[CK_TILE] Extend support of mix precision microscaling BQuant
 (#4267)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

Supported types combinations using BQuant=e8m0:
 - A=bf16
 - B=bf16,bf8,fp4

Summary:
- remove usage of `pk_fp4_raw_t`: consistent with other implementations
and avoid taking into account of the packed size explicitly. In general,
the raw type should not be used because CK Tile internally takes care of
the PackedSize, so using the raw type adds unnecessary complexity to the
implementation
- handle microscaling by checking for `e8m0` type for BQuant (previous
implementation was inconsistent)
 - add support for scaling instructions in `DequantPack8`
 - mx pipeline:
   - extend existing pipeline to support different B types
- add support to scale and cast before writing to LDS or after reading
from LDS (this can be defined in the `Problem` by the user)
 - block gemm:
   - mx pipeline is now using block gemm BQuant
- block gemm BQuant can now load from LDS and apply scale and then call
block gemm universal operator. This adds new functionalities and remove
code duplication
 - warp gemm:
- add case to support 128bit ds_read/write for both A and B when A=16bit
and B=8bit
- add examples and tests: note that some tests for bf16/fp4 already
existed but were removed during previous tests refactoring. I added them
again and other relevant tests for new types combinations

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [ ] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered
This commit is contained in:
Enrico Degregori
2026-02-24 17:57:02 +00:00
committed by assistant-librarian[bot]
parent 3af1a0aafc
commit 4c626aeaa6
44 changed files with 2061 additions and 683 deletions

View File

@@ -2865,6 +2865,12 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_);
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::pk_fp4_t>)
{
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_);
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr4_b64_v2i32(lds_ptr));
}
else
{
static_assert(false, "not implemented");

View File

@@ -50,60 +50,61 @@ constexpr bool is_sequence_suffix_v = is_sequence_suffix<Suffix, Sequence>::valu
template <typename DataType>
struct DefaultTranspose
{
template <index_t LaneGroupSize>
struct Quad16
template <index_t LaneGroupSize, index_t NumBitType>
struct Quad
{
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
"LaneGroupSize must be 64, 32, or 16");
using InputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<4>, sequence<LaneGroupSize / 16, 4, 4>>,
tuple<sequence<2, 1, 2>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<2>>;
using OutputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<LaneGroupSize>, sequence<4>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
// The tile is defined by the LaneGroupSize, which defines the number of lanes in the M/N
// dimensions for the MMA instruction defined by warp gemm.
// The LaneGroupSize is subdivided into groups of 16 (finer granularity of MMA
// instructions), we define these as major subtiles. Each of these major subtile is divided
// into minor subtiles which group the lanes exchanging data during the transpose Example
// LaneGroupSize = 16, 16 bit type:
// - There is 1 group of 16 lanes (1 major subtile)
// - Each major subtile is divided into 4 minor subtiles of (4x4) -> 4 lanes transpose
// the minor subtile and each lane holds 4 elements
// all load transpose instructions use 64 bit right now
static constexpr index_t InstructionBits = 64;
// Subtile major dimension is fixed
static constexpr index_t SubtileMajorDimension = 16;
// Number of subtile major
static constexpr index_t NumSubtilesMajor = LaneGroupSize / 16;
// number of elements loaded by each lane with single instruction, but also number
// of consecutive lanes in a subtile. Subtile is squared (NLanes x NElementsPerLane)
static constexpr index_t SubtileMinorDimension = InstructionBits / NumBitType;
// Number of subtiles minor inside each subtile major
static constexpr index_t NumSubtilesMinor = 16 / SubtileMinorDimension;
using InputEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<SubtileMinorDimension>,
sequence<NumSubtilesMajor, NumSubtilesMinor, SubtileMinorDimension>>,
tuple<sequence<2, 1, 2>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<2>>;
using OutputEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<LaneGroupSize>, sequence<SubtileMinorDimension>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
};
template <index_t LaneGroupSize>
struct Quad8
{
static_assert(LaneGroupSize == 64 || LaneGroupSize == 32 || LaneGroupSize == 16,
"LaneGroupSize must be 64, 32, or 16");
using InputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<8>, sequence<LaneGroupSize / 16, 2, 8>>,
tuple<sequence<2, 1, 2>>,
tuple<sequence<0, 0, 1>>,
sequence<2>,
sequence<2>>;
using OutputEncoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<LaneGroupSize>, sequence<8>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>;
};
static constexpr index_t PackedSize = numeric_traits<remove_cvref_t<DataType>>::PackedSize;
static constexpr index_t NumBitsDataType = (sizeof(DataType) * 8) / PackedSize;
// Select based on data size
template <index_t LaneGroupSize>
using QuadInputEncoding = std::conditional_t<sizeof(DataType) == 2,
typename Quad16<LaneGroupSize>::InputEncoding,
typename Quad8<LaneGroupSize>::InputEncoding>;
using QuadInputEncoding = typename Quad<LaneGroupSize, NumBitsDataType>::InputEncoding;
template <index_t LaneGroupSize>
using QuadOutputEncoding = std::conditional_t<sizeof(DataType) == 2,
typename Quad16<LaneGroupSize>::OutputEncoding,
typename Quad8<LaneGroupSize>::OutputEncoding>;
using QuadOutputEncoding = typename Quad<LaneGroupSize, NumBitsDataType>::OutputEncoding;
// Always swap last two dimensions
static constexpr auto transpose_dims = sequence<1, 0>{};

View File

@@ -78,7 +78,7 @@ struct static_distributed_tensor
constexpr auto sliced_thread_tensor_desc =
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));
thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size() / PackedSize>
sliced_thread_data;
static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {

View File

@@ -287,8 +287,8 @@ struct tensor_view
get_transpose_vectorized_elements(const TensorCoord& coord, index_t linear_offset) const
{
return buf_.template transpose_get<X>(
coord.get_offset(),
linear_offset,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
}
@@ -303,7 +303,8 @@ struct tensor_view
bool is_valid_element // flag
) const
{
return buf_.template transpose_get<X>(coord.get_offset(), linear_offset, is_valid_element);
return buf_.template transpose_get<X>(
coord.get_offset() / PackedSize, linear_offset / PackedSize, is_valid_element);
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X

View File

@@ -736,7 +736,7 @@ struct tile_window_with_static_distribution
.template get_transpose_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, offset);
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto orig_idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
@@ -747,10 +747,12 @@ struct tile_window_with_static_distribution
constexpr auto grouped_idx_ys = group_func(orig_idx_ys);
constexpr index_t linear_distributed_index =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);
tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys) /
Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
vec_value.template get_as<typename Base::DataType>()[j];
vec_value
.template get_as<typename Base::DataType>()[j / Traits::PackedSize];
});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))

View File

@@ -388,49 +388,56 @@ template <typename ADataType,
typename AccDataType,
typename CDataType,
typename QuantGroupSize,
typename BLayout,
bool aquant,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor<ADataType>& a_m_k,
const HostTensor<QDataType>& q,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
CK_TILE_HOST void reference_mx_gemm_bquant(const HostTensor<ADataType>& a_m_k,
const HostTensor<QDataType>& q,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0;
AccDataType pasual = 0;
for(std::size_t k = 0; k < (K / 2); k++)
{
using ComputeType = float;
auto b_scale = type_convert<int32_t>(q((2 * k) / QuantGroupSize::kK, n)) - 127;
ComputeType v_a_0, v_a_1;
ComputeType v_b_0, v_b_1;
AccDataType v_acc = 0;
using ComputeType = float;
ComputeType v_a;
ComputeType v_b;
v_a_0 = ck_tile::type_convert<ComputeType>((a_element_op(a_m_k(m, 2 * k))));
v_a_1 = ck_tile::type_convert<ComputeType>((a_element_op(a_m_k(m, 2 * k + 1))));
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
auto load_b = [&](std::size_t k) -> AccDataType {
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
auto b_pack = type_convert<pk_fp4_t>(b_element_op(b_k_n(k, n)));
auto b_scale_fp4 = type_convert<float>(std::pow(2.0f, b_scale));
auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
v_b_0 = type_convert<ComputeType>(b_f4_lo) * b_scale_fp4;
v_b_1 = type_convert<ComputeType>(b_f4_hi) * b_scale_fp4;
const auto b_pack = type_convert<pk_fp4_t>(b_element_op(b_k_n(k, n)));
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return (n & 1) ? type_convert<ComputeType>(b_pack.unpack(number<1>{}))
: type_convert<ComputeType>(b_pack.unpack(number<0>{}));
}
else
{
return (k & 1) ? type_convert<ComputeType>(b_pack.unpack(number<1>{}))
: type_convert<ComputeType>(b_pack.unpack(number<0>{}));
}
}
else
{
return ck_tile::type_convert<ComputeType>(b_element_op(b_k_n(k, n)));
}
};
pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1;
v_acc += pasual;
for(std::size_t k = 0; k < K; k++)
{
const auto b_scale = type_convert<float>(q(k / QuantGroupSize::kK, n));
v_a = ck_tile::type_convert<ComputeType>(a_element_op(a_m_k(m, k)));
v_b = load_b(k) * b_scale;
v_acc += v_a * v_b;
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
};

View File

@@ -24,6 +24,7 @@ template <> struct DataTypeTraits<pk_int4_t> { static constexpr const char * nam
template <> struct DataTypeTraits<pk_fp4_t> { static constexpr const char * name = "pk_fp4"; };
template <> struct DataTypeTraits<pk_fp6x16_t> { static constexpr const char * name = "pk_fp6x16"; };
template <> struct DataTypeTraits<pk_fp4_raw_t> { static constexpr const char * name = "pk_fp4_raw"; };
template <> struct DataTypeTraits<e8m0_t> { static constexpr const char * name = "e8m0"; };
template <memory_operation_enum MemOp> struct memOpToStr;
template <> struct memOpToStr<memory_operation_enum::set> { static constexpr const char * name = "set"; };

View File

@@ -359,6 +359,260 @@ CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q)
}
#endif
CK_TILE_HOST_DEVICE bf16x8_t bf8x8_to_bf16x8_scale(const bf8x8_t& src, const float& scale)
{
bf16x8_t y;
#if defined(__gfx950__)
constexpr index_t USE_BOTTOM = 0;
constexpr index_t USE_TOP = 1;
auto convert_quartet = [&](index_t src_offset, index_t dst_offset) {
union
{
uint32_t packed;
bf8_t elements[4];
} input;
union
{
bf16x2_t vec;
bf16_t elements[2];
} output;
input.elements[0] = src[src_offset];
input.elements[1] = src[src_offset + 1];
input.elements[2] = src[src_offset + 2];
input.elements[3] = src[src_offset + 3];
output.vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.packed, scale, USE_BOTTOM);
y[dst_offset] = output.elements[0];
y[dst_offset + 1] = output.elements[1];
output.vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.packed, scale, USE_TOP);
y[dst_offset + 2] = output.elements[0];
y[dst_offset + 3] = output.elements[1];
};
convert_quartet(0, 0);
convert_quartet(4, 4);
#else
static_for<0, 8, 1>{}([&](auto i) {
y[i.value] = type_convert<bf16_t>(type_convert<float>(src[i.value]) * scale);
});
#endif
return y;
}
CK_TILE_HOST_DEVICE bf16x8_t fp8x8_to_bf16x8_scale(const fp8x8_t& src, const float& scale)
{
bf16x8_t y;
#if defined(__gfx950__)
constexpr index_t USE_BOTTOM = 0;
constexpr index_t USE_TOP = 1;
auto convert_quartet = [&](index_t src_offset, index_t dst_offset) {
union
{
uint32_t packed;
fp8_t elements[4];
} input;
union
{
bf16x2_t vec;
bf16_t elements[2];
} output;
input.elements[0] = src[src_offset];
input.elements[1] = src[src_offset + 1];
input.elements[2] = src[src_offset + 2];
input.elements[3] = src[src_offset + 3];
output.vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(input.packed, scale, USE_BOTTOM);
y[dst_offset] = output.elements[0];
y[dst_offset + 1] = output.elements[1];
output.vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(input.packed, scale, USE_TOP);
y[dst_offset + 2] = output.elements[0];
y[dst_offset + 3] = output.elements[1];
};
convert_quartet(0, 0);
convert_quartet(4, 4);
#else
static_for<0, 8, 1>{}([&](auto i) {
y[i.value] = type_convert<bf16_t>(type_convert<float>(src[i.value]) * scale);
});
#endif
return y;
}
CK_TILE_HOST_DEVICE fp16x8_t fp8x8_to_fp16x8_scale(const fp8x8_t& src, const float& scale)
{
fp16x8_t y;
#if defined(__gfx950__)
constexpr index_t USE_BOTTOM = 0;
constexpr index_t USE_TOP = 1;
auto convert_quartet = [&](index_t src_offset, index_t dst_offset) {
union
{
uint32_t packed;
fp8_t elements[4];
} input;
union
{
fp16x2_t vec;
fp16_t elements[2];
} output;
input.elements[0] = src[src_offset];
input.elements[1] = src[src_offset + 1];
input.elements[2] = src[src_offset + 2];
input.elements[3] = src[src_offset + 3];
output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(input.packed, scale, USE_BOTTOM);
y[dst_offset] = output.elements[0];
y[dst_offset + 1] = output.elements[1];
output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(input.packed, scale, USE_TOP);
y[dst_offset + 2] = output.elements[0];
y[dst_offset + 3] = output.elements[1];
};
convert_quartet(0, 0);
convert_quartet(4, 4);
#else
static_for<0, 8, 1>{}([&](auto i) {
y[i.value] = type_convert<fp16_t>(type_convert<float>(src[i.value]) * scale);
});
#endif
return y;
}
CK_TILE_HOST_DEVICE fp16x8_t bf8x8_to_fp16x8_scale(const bf8x8_t& src, const float& scale)
{
fp16x8_t y;
#if defined(__gfx950__)
constexpr index_t USE_BOTTOM = 0;
constexpr index_t USE_TOP = 1;
auto convert_quartet = [&](index_t src_offset, index_t dst_offset) {
union
{
uint32_t packed;
bf8_t elements[4];
} input;
union
{
fp16x2_t vec;
fp16_t elements[2];
} output;
input.elements[0] = src[src_offset];
input.elements[1] = src[src_offset + 1];
input.elements[2] = src[src_offset + 2];
input.elements[3] = src[src_offset + 3];
output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(input.packed, scale, USE_BOTTOM);
y[dst_offset] = output.elements[0];
y[dst_offset + 1] = output.elements[1];
output.vec = __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(input.packed, scale, USE_TOP);
y[dst_offset + 2] = output.elements[0];
y[dst_offset + 3] = output.elements[1];
};
convert_quartet(0, 0);
convert_quartet(4, 4);
#else
static_for<0, 8, 1>{}([&](auto i) {
y[i.value] = type_convert<fp16_t>(type_convert<float>(src[i.value]) * scale);
});
#endif
return y;
}
CK_TILE_HOST_DEVICE bf16x8_t fp4x4_to_bf16x8_scale(const pk_fp4x4_t& src, const float& scale)
{
bf16x8_t y;
#if defined(__gfx950__)
union
{
uint32_t u32;
pk_fp4x4_t pf4;
} cvt;
constexpr index_t USE_BYTE_0 = 0;
constexpr index_t USE_BYTE_1 = 1;
constexpr index_t USE_BYTE_2 = 2;
constexpr index_t USE_BYTE_3 = 3;
cvt.pf4 = src;
bf16x2_t y0 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_0);
bf16x2_t y1 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_1);
bf16x2_t y2 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_2);
bf16x2_t y3 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, USE_BYTE_3);
y[0] = y0[0];
y[1] = y0[1];
y[2] = y1[0];
y[3] = y1[1];
y[4] = y2[0];
y[5] = y2[1];
y[6] = y3[0];
y[7] = y3[1];
#else
static_for<0, 4, 1>{}([&](auto i) {
auto yi = pk_fp4_to_bf16x2(src[i.value], scale);
y[2 * i.value] = yi[0];
y[2 * i.value + 1] = yi[1];
});
#endif
return y;
}
CK_TILE_HOST_DEVICE fp16x8_t fp4x4_to_fp16x8_scale(const pk_fp4x4_t& src, const float& scale)
{
fp16x8_t y;
#if defined(__gfx950__)
union
{
uint32_t u32;
pk_fp4x4_t pf4;
} cvt;
constexpr index_t USE_BYTE_0 = 0;
constexpr index_t USE_BYTE_1 = 1;
constexpr index_t USE_BYTE_2 = 2;
constexpr index_t USE_BYTE_3 = 3;
cvt.pf4 = src;
fp16x2_t y0 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_0);
fp16x2_t y1 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_1);
fp16x2_t y2 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_2);
fp16x2_t y3 = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(cvt.u32, scale, USE_BYTE_3);
y[0] = y0[0];
y[1] = y0[1];
y[2] = y1[0];
y[3] = y1[1];
y[4] = y2[0];
y[5] = y2[1];
y[6] = y3[0];
y[7] = y3[1];
#else
static_for<0, 4, 1>{}([&](auto i) {
auto yi = pk_fp4_to_fp16x2(src[i.value], scale);
y[2 * i.value] = yi[0];
y[2 * i.value + 1] = yi[1];
});
#endif
return y;
}
struct PassThroughPack8
{
static constexpr const char* name = "PassThroughPack8";
@@ -437,6 +691,50 @@ struct DequantPack8
y.hi = i4_to_half4_scale(bit_cast<int>(x) >> 8, z);
}
CK_TILE_HOST_DEVICE constexpr void
operator()(bf16x8_t& y, const pk_fp4x4_t& x, const float& z) const
{
y = fp4x4_to_bf16x8_scale(x, z);
}
CK_TILE_HOST_DEVICE constexpr void
operator()(fp16x8_t& y, const pk_fp4x4_t& x, const float& z) const
{
y = fp4x4_to_fp16x8_scale(x, z);
}
CK_TILE_HOST_DEVICE constexpr void
operator()(bf16x8_t& y, const bf8x8_t& x, const float& z) const
{
y = bf8x8_to_bf16x8_scale(x, z);
}
CK_TILE_HOST_DEVICE constexpr void
operator()(bf16x8_t& y, const fp8x8_t& x, const float& z) const
{
y = fp8x8_to_bf16x8_scale(x, z);
}
CK_TILE_HOST_DEVICE constexpr void
operator()(fp16x8_t& y, const fp8x8_t& x, const float& z) const
{
y = fp8x8_to_fp16x8_scale(x, z);
}
CK_TILE_HOST_DEVICE constexpr void
operator()(fp16x8_t& y, const bf8x8_t& x, const float& z) const
{
y = bf8x8_to_fp16x8_scale(x, z);
}
CK_TILE_HOST_DEVICE constexpr void
operator()(bf16x8_t& y, const bf16x8_t& x, const float& z) const
{
static_for<0, 8, 1>{}([&](auto i) {
y[i.value] = type_convert<bf16_t>(type_convert<float>(x[i.value]) * z);
});
}
constexpr const static bool is_pack8_invocable = true;
};

View File

@@ -99,7 +99,7 @@ struct CShuffleEpilogue
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_t> ||
std::is_same_v<BDataType, pk_fp4_raw_t>,
sizeof(BDataType) < sizeof(ADataType),
ADataType,
BDataType>;

View File

@@ -97,7 +97,8 @@ struct BlockUniversalGemmAsBsCr
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_raw_t>,
std::is_same_v<BDataType, pk_fp4_t> ||
sizeof(BDataType) < sizeof(ADataType),
ADataType,
BDataType>;

View File

@@ -20,8 +20,23 @@ struct GemmPipelineAgBgCrImplBase
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
using BInDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
using BDataType =
std::conditional_t<std::is_same_v<BInDataType, pk_fp4_raw_t>, ADataType, BInDataType>;
template <typename T>
using has_bcastpolicy_type = decltype(T::BCastPolicy);
static constexpr bool IsBCastPolicyBeforeLDSWrite = [] {
if constexpr(is_detected<has_bcastpolicy_type, Problem>{})
{
return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
}
else
{
return false;
}
}();
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite, ADataType, BInDataType>;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
@@ -226,6 +241,12 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view,
const ALdsLoadTileDistr&) const
{
// with pk_int4_t load transpose the LDS type is always BDataType
using ADataTypeLDS =
std::conditional_t<std::is_same_v<typename Problem::ADataType, pk_int4_t>,
typename Problem::BDataType,
typename Problem::ADataType>;
auto a_lds_shape = []() {
if constexpr(is_a_load_tr)
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
@@ -238,9 +259,8 @@ struct GemmPipelineAgBgCrImplBase
auto a_lds_load_tile_distr = []() {
if constexpr(is_a_load_tr)
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename ALdsLoadTileDistr::DstrEncode,
typename Problem::ADataType>::TransposedDstrEncode{});
typename InputTileDistributionTraits<typename ALdsLoadTileDistr::DstrEncode,
ADataTypeLDS>::TransposedDstrEncode{});
else
return ALdsLoadTileDistr{};
}();
@@ -313,10 +333,9 @@ struct GemmPipelineAgBgCrImplBase
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
using BLdsDataType =
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
using BLdsDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
typename Problem::ADataType,
typename Problem::BDataType>;
auto b_lds_load_tile_distr = []() {
if constexpr(is_b_load_tr)

View File

@@ -10,6 +10,12 @@
namespace ck_tile {
enum struct CastPolicy
{
BeforeLDSWrite,
AfterLDSRead,
};
enum struct GemmPipelineScheduler
{
Default,

View File

@@ -80,6 +80,21 @@ struct UniversalGemmBasePolicy
static constexpr bool is_b_load_tr = false;
#endif
template <typename T>
using has_bcastpolicy_type = decltype(T::BCastPolicy);
template <typename Problem>
static constexpr bool IsBCastPolicyBeforeLDSWrite_v = [] {
if constexpr(is_detected<has_bcastpolicy_type, Problem>{})
{
return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
}
else
{
return false;
}
}();
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
@@ -305,11 +320,11 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType =
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
@@ -589,15 +604,14 @@ struct UniversalGemmBasePolicy
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB()
{
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
using BInDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
using BDataType = std::conditional_t<std::is_same_v<BInDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
typename Problem::ADataType,
typename Problem::BDataType>;
if constexpr(Problem::FixedVectorSize)
{
@@ -739,13 +753,13 @@ struct UniversalGemmBasePolicy
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t KPerBlock = std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>
? Problem::BlockGemmShape::kK / 2
: Problem::BlockGemmShape::kK;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// If we cast before writing to LDS, the vectorsize is defined by the A type
// since the assumption is that A type is going to be the B LDS type
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
constexpr index_t VecLoadSize =
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>
? 4
IsBCastPolicyBeforeLDSWrite
? (Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>())
: (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>());
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using BLayout = remove_cvref_t<
@@ -855,10 +869,10 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
{
using BDataType =
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
constexpr index_t smem_size_b = integer_least_multiple(
b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16);
@@ -900,7 +914,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_raw_t>,
std::is_same_v<BDataType, pk_fp4_t> ||
sizeof(BDataType) < sizeof(ADataType),
ADataType,
BDataType>;

View File

@@ -185,16 +185,35 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAttributeMfmaItera
#endif
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
AttrNumAccess>>;
AttrNumAccessA,
AttrNumAccessB>>;
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32<WGAttrCtlEnum::Default_>,
2,
AttrNumAccessA,
AttrNumAccessB>>;
#else
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
AttrNumAccessA>>;
template <WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
4,
AttrNumAccessA,
AttrNumAccessB>>;
#endif
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAttributeMfmaIterateK_SwizzleA<

View File

@@ -17,13 +17,47 @@ enum class WGAttrNumAccessEnum
Invalid = -1
};
template <WGAttrNumAccessEnum AttrNumAccess>
struct get_wgattr_num_access
{
private:
static constexpr index_t getAccesses()
{
if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Single)
{
return 1;
}
else if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Double)
{
return 2;
}
else if constexpr(AttrNumAccess == WGAttrNumAccessEnum::Quad)
{
return 4;
}
else
{
static_assert(false, "unsupported AttrNumAccess");
return 0;
}
}
public:
static constexpr auto value = getAccesses();
};
template <typename WarpGemmAttributeMfmaImpl_,
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
WGAttrNumAccessEnum AttrNumAccessA_ = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB_ = AttrNumAccessA_>
struct WarpGemmAttributeMfma
{
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccess = AttrNumAccess_;
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccessA = AttrNumAccessA_;
static constexpr auto AttrNumAccessAV = get_wgattr_num_access<AttrNumAccessA>::value;
static constexpr auto AttrNumAccessB = AttrNumAccessB_;
static constexpr auto AttrNumAccessBV = get_wgattr_num_access<AttrNumAccessB>::value;
static constexpr bool UsePackNumAccess = AttrNumAccessA != AttrNumAccessB;
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
@@ -44,12 +78,13 @@ struct WarpGemmAttributeMfma
static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
"Multi-block WarpGemmAttributeMfmaImpl is not supported");
template <index_t kMNLane>
template <index_t kMNLane, index_t AttrNumAccessV_>
static constexpr auto get_warp_dstr_encoding()
{
static_assert(kKPerThread % AttrNumAccessV == 0,
static_assert(kKPerThread % AttrNumAccessV_ == 0,
"kKPerThread must be divisible by NumAccess");
if constexpr(AttrNumAccessV == 1)
if constexpr(AttrNumAccessV_ == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
@@ -57,18 +92,48 @@ struct WarpGemmAttributeMfma
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
{
// AttrNumAccess splits the kABKPerLane
// We can split them but still have them contiguous (packed) or have them interleaved.
// The reason to split the dimension but still have it packed is to match load transpose
// encoding when A and B use different AttrNumAccess (they have different types in LDS)
// Example
// A: 16bit, B: 8bit
// Load transpose B: lane0 -> K=0..7 (only 1 instruction)
// Load transpose A: lane0 -> K=0..3 first instruction, K=4..7 second instruction
// In this way the data in register are consistent between A and B
if constexpr(UsePackNumAccess)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<Impl::kABKLane,
AttrNumAccessV_,
Impl::kABKPerLane / AttrNumAccessV_>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2, 2>,
sequence<1, 2>>{};
}
else
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<AttrNumAccessV_,
Impl::kABKLane,
Impl::kABKPerLane / AttrNumAccessV_>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
}
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane, AttrNumAccessAV>());
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane, AttrNumAccessBV>());
using CWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
@@ -121,14 +186,19 @@ struct WarpGemmAttributeMfma
template <typename WarpGemmAttributeMfmaImpl_,
index_t kKIter,
WGAttrNumAccessEnum AttrNumAccess_ = WGAttrNumAccessEnum::Single>
WGAttrNumAccessEnum AttrNumAccessA_ = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB_ = AttrNumAccessA_>
struct WarpGemmAttributeMfmaIterateK
{
static_assert(kKIter > 0, "wrong!");
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccess = AttrNumAccess_;
static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
using Impl = remove_cvref_t<WarpGemmAttributeMfmaImpl_>;
static constexpr auto AttrNumAccessA = AttrNumAccessA_;
static constexpr auto AttrNumAccessAV = get_wgattr_num_access<AttrNumAccessA>::value;
static constexpr auto AttrNumAccessB = AttrNumAccessB_;
static constexpr auto AttrNumAccessBV = get_wgattr_num_access<AttrNumAccessB>::value;
static constexpr bool UsePackNumAccess = AttrNumAccessA != AttrNumAccessB;
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;
@@ -151,14 +221,15 @@ struct WarpGemmAttributeMfmaIterateK
static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
"Multi-block on both M & N directions is not supported");
template <index_t kMNLane, index_t kMNBlock, index_t kNMBlock>
template <index_t kMNLane, index_t kMNBlock, index_t kNMBlock, index_t AttrNumAccessV_>
CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding()
{
if constexpr(kMNBlock == 1 && kNMBlock == 1)
{
static_assert(kKPerThread % AttrNumAccessV == 0,
static_assert(kKPerThread % AttrNumAccessV_ == 0,
"kKPerThread must be divisible by NumAccess");
if constexpr(AttrNumAccessV == 1)
if constexpr(AttrNumAccessV_ == 1)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
@@ -166,21 +237,40 @@ struct WarpGemmAttributeMfmaIterateK
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>{};
}
else
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<AttrNumAccessV,
Impl::kABKLane,
Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
{
if constexpr(UsePackNumAccess)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<Impl::kABKLane,
AttrNumAccessV_,
Impl::kABKPerLane * kKIter / AttrNumAccessV_>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2, 2>,
sequence<1, 2>>{};
}
else
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<kMNLane>,
sequence<AttrNumAccessV_,
Impl::kABKLane,
Impl::kABKPerLane * kKIter / AttrNumAccessV_>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{};
}
}
}
else if constexpr(kMNBlock == 1 && 1 < kNMBlock)
{
static_assert(AttrNumAccessV == 1,
static_assert(AttrNumAccessV_ == 1,
"Multiple access is not supported when using multi-block");
// each M/N blocks share the same data
return tile_distribution_encoding<
@@ -193,7 +283,7 @@ struct WarpGemmAttributeMfmaIterateK
}
else if constexpr(1 < kMNBlock && kNMBlock == 1)
{
static_assert(AttrNumAccessV == 1,
static_assert(AttrNumAccessV_ == 1,
"Multiple access is not supported when using multi-block");
// single block to multi-block thread mapping
return tile_distribution_encoding<
@@ -245,10 +335,14 @@ struct WarpGemmAttributeMfmaIterateK
}
}
using AWarpDstrEncoding =
decltype(get_warp_dstr_encoding<Impl::kAMLane, Impl::kAMBlock, Impl::kBNBlock>());
using BWarpDstrEncoding =
decltype(get_warp_dstr_encoding<Impl::kBNLane, Impl::kBNBlock, Impl::kAMBlock>());
using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane,
Impl::kAMBlock,
Impl::kBNBlock,
AttrNumAccessAV>());
using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane,
Impl::kBNBlock,
Impl::kAMBlock,
AttrNumAccessBV>());
using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding());
// c_vec += a_vec * b_vec

View File

@@ -24,9 +24,10 @@ template <typename AType,
index_t NPerWave,
index_t KPerWave,
bool TransposeC,
bool SwizzleA = false,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum AttrNumAccess = ESingle>
bool SwizzleA = false,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum AttrNumAccessA = ESingle,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
struct Dispatcher;
// clang-format off
@@ -78,6 +79,10 @@ template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true> { using
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, false, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<EDouble>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<EDouble>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false, false, false, EDouble, ESingle> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<EDouble, ESingle>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 64, false, false, false, EQuad, ESingle> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<EQuad, ESingle>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 64, false, false, false, EQuad> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<EQuad>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 64, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<EDouble>; };
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, true, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<EDouble>; };
@@ -166,9 +171,10 @@ template <typename AType,
index_t NPerWave,
index_t KPerWave,
bool TransposeC,
bool SwizzleA = false,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
bool SwizzleA = false,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< //
AType,
BType,
@@ -179,6 +185,7 @@ using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< //
TransposeC,
SwizzleA,
UseStructuredSparsity,
AttrNumAccess>::Type;
AttrNumAccessA,
AttrNumAccessB>::Type;
} // namespace ck_tile

View File

@@ -24,9 +24,9 @@
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp"

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/elementwise.hpp"
@@ -101,20 +102,33 @@ struct BQuantBlockUniversalGemmAsBsCr
// 2. bf8, bf8, fp32 -> f32
// 3. i4, fp8, (fp8/fp32) -> f32
// 4. i4, bf8, (fp8/fp32) -> f32
static_assert((std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>) &&
(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, ck_tile::pk_int4_t>) &&
(std::is_same_v<BQDataType, float> ||
std::is_same_v<BQDataType, ck_tile::fp8_t> ||
std::is_same_v<BQDataType, ck_tile::bf8_t>) &&
(std::is_same_v<ComputeDataType, fp8_t> ||
std::is_same_v<ComputeDataType, bf8_t>) &&
std::is_same_v<CDataType, fp32_t>);
// 5. bf16, (bf16/bf8/fp8/fp4), e8m0 -> f32
// 6. fp16, (fp16/fp8/bf8/fp4), e8m0 -> f32
static_assert(
is_any_of<ADataType, fp8_t, bf8_t, bf16_t, fp16_t>::value &&
is_any_of<BDataType, fp8_t, bf8_t, pk_int4_t, bf16_t, pk_fp4_t, fp16_t>::value &&
is_any_of<BQDataType, float, fp8_t, bf8_t, e8m0_t>::value &&
is_any_of<ComputeDataType, fp8_t, bf8_t, bf16_t, fp16_t>::value &&
std::is_same_v<CDataType, fp32_t>);
static constexpr index_t InterWaveSchedulingMacClusters = 1;
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
template <typename T>
using has_bcastpolicy_type = decltype(T::BCastPolicy);
static constexpr bool IsBCastPolicyBeforeLDSWrite = [] {
if constexpr(is_detected<has_bcastpolicy_type, Problem>{})
{
return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
}
else
{
return false;
}
}();
};
public:
@@ -127,9 +141,12 @@ struct BQuantBlockUniversalGemmAsBsCr
using CDataType = remove_cvref_t<typename Traits::CDataType>;
// BDataType gets converted from PkInt4 during loading
// OverrideBDataType is only used when BCastPolicy is CastBeforeLDSWrite for microscale.
// In that case we use ADataType
using OverrideBDataType = std::conditional_t<
std::is_same_v<BDataType, pk_int4_t> &&
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>,
(std::is_same_v<BDataType, pk_int4_t> &&
std::is_same_v<typename Traits::BLayout, tensor_layout::gemm::RowMajor>) ||
Traits::IsBCastPolicyBeforeLDSWrite,
ADataType,
BDataType>;
@@ -176,57 +193,17 @@ struct BQuantBlockUniversalGemmAsBsCr
using I0 = number<0>;
using I1 = number<1>;
// Use gemm universal block distribution encoding instead of duplicating it
using BlockGemmBase = BlockUniversalGemmAsBsCr<Problem_, Policy_, UnaryOpSize_>;
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KIterInterwave>,
sequence<KIterPerWarp>>;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, KIterSeq>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
return BlockGemmBase::MakeABlockDistributionEncode();
}
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KIterInterwave>,
sequence<KIterPerWarp>>;
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, KIterSeq>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
return BlockGemmBase::MakeBBlockDistributionEncode();
}
private:
@@ -235,20 +212,24 @@ struct BQuantBlockUniversalGemmAsBsCr
{
};
using BlockGemmImplBase = typename BlockUniversalGemmAsBsCr<Problem_, Policy_, UnaryOpSize_>::
template BlockGemmImpl<GemmPipelineScheduler::Intrawave, Traits>;
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits> : public BlockGemmImplBase
{
static constexpr auto ALdsTileDistr =
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using BlockGemmImplBase::a_warp_tile_;
using BlockGemmImplBase::b_warp_tile_;
using BlockGemmImplBase::BLdsTileDistr;
// If we apply scale while reading from LDS, then we can use the operator() from
// BlockUniversalGemmAsBsCr
using BlockGemmImplBase::operator();
using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
ALdsTile a_warp_tile_;
BLdsTile b_warp_tile_;
// static distributed tensor with LDS type
using BTypeTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
BTypeTile b_warp_tile_lds_;
// Load from LDS (assumption is that the scale will be applied in the block gemm)
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
@@ -265,6 +246,107 @@ struct BQuantBlockUniversalGemmAsBsCr
b_warp_tile_, b_block_window);
}
// Load from LDS and scale (then the tile can directly be consumed in the block gemm)
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
typename BQRegBlockTile,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
const BQRegBlockTile& bq_block_tensor,
bool_constant<ALoadTranspose> = {},
bool_constant<BLoadTranspose> = {})
{
// Load tile from LDS
// Do not use load_int4_tile here because it will have support to cast from fp4 to
// compute type, while here we want to only load from LDS and then apply the scale
// and cast later
if constexpr(ALoadTranspose)
{
a_warp_tile_ = load_tile_transpose(a_block_window);
}
else
{
load_tile(a_warp_tile_, a_block_window);
}
if constexpr(BLoadTranspose)
{
b_warp_tile_lds_ = load_tile_transpose(b_block_window);
}
else
{
load_tile(b_warp_tile_lds_, b_block_window);
}
// Apply scale and cast
using BDataTypeRaw =
std::conditional_t<std::is_same_v<BDataType, pk_fp4_t>, pk_fp4_t::type, BDataType>;
constexpr index_t warp_size = get_warp_size();
constexpr index_t nelements = WarpGemm::kK * WarpGemm::kN / warp_size;
constexpr index_t thread_buffer_size = nelements / UnaryOpSize_;
const element_wise::DequantPack8 elementwise_op{};
using SrcVectorRawType = ext_vector_t<BDataTypeRaw, UnaryOpSize_ / BPackedSize>;
using DstVectorType = ext_vector_t<ComputeDataType, UnaryOpSize_>;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
// B scale register offset
constexpr index_t reg_offset = [&]() {
if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
return ((nIter * NWarp * WarpGemm::kN) /
GemmTraits::BQuantGroupSize::kN) *
Traits::KQPerBlock +
kQScale;
else
{
return nIter * Traits::KQPerBlock + kQScale;
}
}();
// Get B scale from thread buffer
auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
float b_scale_f = float(scale_reg);
static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
// Thread buffers
using BWarpThreadBuffer = decltype(b_warp_tile_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
using BLDSThreadBuffer = decltype(b_warp_tile_lds_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)));
BWarpThreadBuffer b_warp_thread_buffer;
BLDSThreadBuffer b_lds_thread_buffer;
// Load thread buffer from tile (LDS type)
b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// Apply scale to B thread buffer and cast
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
elementwise_op(
b_warp_thread_buffer.template get_as<DstVectorType>()(i),
b_lds_thread_buffer.template get_as<SrcVectorRawType>()[i],
b_scale_f);
});
// Store B thread buffer to tile (MMA type)
b_warp_tile_.set_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths),
b_warp_thread_buffer);
});
});
});
}
// C += A * B
template <typename CBlockTensor,
typename BQBlockTensor,
@@ -400,6 +482,7 @@ struct BQuantBlockUniversalGemmAsBsCr
MakeCBlockTile();
}
// Read A and B from LDS
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
bool ALoadTranspose = false,
@@ -412,7 +495,24 @@ struct BQuantBlockUniversalGemmAsBsCr
block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr);
}
// Read A and B from LDS and apply scale to B
template <typename ASmemBlockWindow,
typename BSmemBlockWindow,
typename BQRegBlockTile,
bool ALoadTranspose = false,
bool BLoadTranspose = false>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window,
BQRegBlockTile bq_block_tile,
bool_constant<ALoadTranspose> a_load_tr = {},
bool_constant<BLoadTranspose> b_load_tr = {})
{
block_gemm_impl_.LocalPrefetch(
a_block_window, b_block_window, bq_block_tile, a_load_tr, b_load_tr);
}
// C += A * B
// Apply scale after MMA
template <typename CBlockTensor,
typename BQBlockTensor,
typename ASmemBlockWindow,
@@ -425,6 +525,16 @@ struct BQuantBlockUniversalGemmAsBsCr
block_gemm_impl_(c_block_tensor, bq_block_tensor, a_block_window, b_block_window);
}
// C += A * B
// Scale has already been applied to B, so this is using the gemm universal block implementation
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
block_gemm_impl_(c_block_tensor, a_block_window, b_block_window);
}
private:
BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
};

View File

@@ -787,20 +787,12 @@ struct QuantGemmKernel
}
else
{
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, k_size / 2),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
else
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, k_size),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, k_size),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
}
@@ -814,16 +806,10 @@ struct QuantGemmKernel
}
else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / 2>{}),
sequence<false, GemmPipeline::kPadK>{});
else
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
@@ -848,17 +834,10 @@ struct QuantGemmKernel
{
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / 2>{}),
{i_n, 0});
else
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
}
else
{

View File

@@ -10,7 +10,7 @@
namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
struct GemmMicroscalePipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
{
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
@@ -42,10 +42,14 @@ struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Probl
CK_TILE_DEVICE constexpr auto
GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using YPerTile = number<NPerBlockBQ>;
using XPerTile = number<KPerBlockBQ>;
using YPerTile =
std::conditional_t<std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
number<NPerBlockBQ>,
number<KPerBlockBQ>>;
using XPerTile =
std::conditional_t<std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>,
number<KPerBlockBQ>,
number<NPerBlockBQ>>;
auto bq_copy_dram_window =
make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),

View File

@@ -0,0 +1,296 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "gemm_group_quant_utils.hpp"
namespace ck_tile {
struct GemmMicroscalePipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
{
using Base = UniversalGemmPipelineAgBgCrPolicy;
using Base::I0;
using Base::I1;
using Base::I2;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ()
{
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
// Support both RowMajor and ColumnMajor layouts for BQ
if constexpr(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetABQGlobalVectorLoadSize<Problem, BQDataType, KPerBlockBQ, NPerBlockBQ>();
}
else
{
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBRegTileDistribution()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// Tile: KPerBlock X NPerBlock
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
getBTileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
// Tile: NPerBlock X KPerBlock
else
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
NPerBlock,
KPerBlock,
VecLoadSize,
getBTileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution()
{
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
// If we apply scale before writing to LDS, we need a tile distribution for
// BQuant consistent with global memory reading of matrix B, while
// if we apply scale after reading from LDS, we need a tile distribution for
// BQuant consistent with the MMA instructions layout
if constexpr(Problem::BCastPolicy == CastPolicy::AfterLDSRead)
{
using BlockGemmShape = typename Problem::BlockGemmShape;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using TileEncodingPattern =
tile_distribution_encoding_pattern_bq<BlockGemmShape,
WarpGemm,
BlockSize,
KPerBlockBQ, // Logical K dimension
NPerBlockBQ, // Logical N dimension
Problem::BQuantGroupSize::kN,
Problem::BQuantGroupSize::kK,
BQLayout>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
else
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
constexpr index_t warp_size = get_warp_size();
constexpr index_t num_warps = BlockSize / get_warp_size();
constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size);
constexpr index_t b_vec = VecLoadSize > LargestVec ? LargestVec : VecLoadSize;
constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK;
// For each BQ layout we need different encodings whether B has the same layout or not
// TODO: generalize encodings for different BQuantGroupSize granularity
if constexpr(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(std::is_same_v<BQLayout, BLayout>)
{
constexpr index_t K0 = KPerBlock / b_vec;
constexpr index_t K1 = K0 / KScale;
constexpr index_t K3 = KScale;
constexpr index_t K2 = 1;
constexpr index_t N0 = num_warps / NumWaveGroups;
constexpr index_t N1 = warp_size / K0;
constexpr index_t N2 = NPerBlock / (N0 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<K1>,
tuple<sequence<N0, N1, N2>, sequence<K3, K2>>,
tuple<sequence<1>, sequence<1, 2, 0>>,
tuple<sequence<0>, sequence<1, 0, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
else
{
constexpr index_t N1 = NPerBlock / b_vec;
constexpr index_t N2 = b_vec;
constexpr index_t KRepeatInWave = warp_size / N1;
constexpr index_t KRepeatAcrossWave = num_warps / KScale;
constexpr index_t K2 = num_warps / KRepeatAcrossWave;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<KRepeatAcrossWave, KRepeatInWave>,
tuple<sequence<1, N1, N2>, sequence<K2, 1, 1>>,
tuple<sequence<1, 2, 0>, sequence<0, 1, 2>>,
tuple<sequence<0, 0, 0>, sequence<1, 1, 1>>,
sequence<1, 2>,
sequence<2, 2>>{});
}
}
else
{
if constexpr(std::is_same_v<BQLayout, BLayout>)
{
constexpr index_t NScale = NPerBlock / Problem::BQuantGroupSize::kN;
constexpr index_t N0 = NScale / b_vec;
constexpr index_t N1 = b_vec;
constexpr index_t KLanes = warp_size / N0;
constexpr index_t KVec = KPerBlock / KLanes / num_warps;
constexpr index_t KRepeat = KPerBlock / KScale / KVec;
constexpr index_t KRepeatInWave = KRepeat > KLanes ? KLanes : 1;
constexpr index_t KRepeatAcrossWave = KRepeat > KLanes ? KRepeat / KLanes : 1;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<KRepeatAcrossWave, KRepeatInWave>,
tuple<sequence<1, 1, 1>, sequence<N0, N1>>,
tuple<sequence<1, 0>, sequence<1, 0, 2>>,
tuple<sequence<0, 0>, sequence<1, 1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
else
{
constexpr index_t KRepeatInWave = Problem::BQuantGroupSize::kK / b_vec;
constexpr index_t K1 = KScale;
constexpr index_t N0 = num_warps / NumWaveGroups;
constexpr index_t N1 = warp_size / (KRepeatInWave * K1);
// Number of contiguous elements in N dimension when reading B matrix
// becomes the vector size of BQ
constexpr index_t N2 = NPerBlock / (BlockSize / (KPerBlock / b_vec));
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1, 1, KRepeatInWave>,
tuple<sequence<1, K1, 1>, sequence<N0, N1, N2>>,
tuple<sequence<1, 0, 2>, sequence<2, 0, 1, 0>>,
tuple<sequence<0, 0, 0>, sequence<1, 1, 1, 2>>,
sequence<1, 2>,
sequence<2, 2>>{});
}
}
}
}
// Return AttrNumAccess for a given warp tile (defined by ThreadElements) and data type
template <typename DataType, bool UseLoadTranspose, index_t ThreadElements>
static constexpr auto GetAttrNumAccess(bool_constant<UseLoadTranspose>, number<ThreadElements>)
{
constexpr index_t PackedSize = numeric_traits<remove_cvref_t<DataType>>::PackedSize;
constexpr index_t vector_size = DS_READ_TR_SIZE() / sizeof(DataType) * PackedSize;
return !UseLoadTranspose ? WGAttrNumAccessEnum::Single
: vector_size == ThreadElements ? WGAttrNumAccessEnum::Single
: vector_size * 2 == ThreadElements ? WGAttrNumAccessEnum::Double
: vector_size * 4 == ThreadElements ? WGAttrNumAccessEnum::Quad
: WGAttrNumAccessEnum::Invalid;
};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using ComputeDataType = typename Problem::ComputeDataType;
using LDSADataType = typename Problem::ADataType;
using LDSBDataType = std::conditional_t<Problem::BCastPolicy == CastPolicy::BeforeLDSWrite,
ComputeDataType,
typename Problem::BDataType>;
static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0,
"KPerWarpGemm must be a multiple of QuantGroupSize!");
constexpr auto thread_elements =
number<WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size()>{};
constexpr auto is_a_load_tr_v = bool_constant<Base::template is_a_load_tr<Problem>>{};
constexpr auto is_b_load_tr_v = bool_constant<Base::template is_b_load_tr<Problem>>{};
constexpr auto is_any_load_tr = is_a_load_tr_v || is_b_load_tr_v;
constexpr auto wg_attr_num_access_compute =
GetAttrNumAccess<ComputeDataType>(is_any_load_tr, thread_elements);
constexpr auto wg_attr_num_accessA =
std::is_same_v<LDSADataType, LDSBDataType>
? wg_attr_num_access_compute
: GetAttrNumAccess<LDSADataType>(is_a_load_tr_v, thread_elements);
constexpr auto wg_attr_num_accessB =
std::is_same_v<LDSADataType, LDSBDataType>
? wg_attr_num_access_compute
: GetAttrNumAccess<LDSBDataType>(is_b_load_tr_v, thread_elements);
using WarpGemm = WarpGemmDispatcher<ComputeDataType,
ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
wg_attr_num_accessA,
wg_attr_num_accessB>;
static_assert(is_any_of<ComputeDataType, fp8_t, bf8_t, bf16_t, fp16_t>::value);
static_assert(std::is_same_v<typename Problem::CDataType, float>);
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<
typename Problem::ADataType,
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_t>,
typename Problem::ADataType,
typename Problem::BDataType>,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BQuantBlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile

View File

@@ -9,7 +9,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_microscale_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
@@ -18,15 +18,21 @@ namespace ck_tile {
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = GemmMxFp4PipelineAgBgCrPolicy>
struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
template <typename Problem, typename Policy = GemmMicroscalePipelineAgBgCrPolicy>
struct MicroscaleGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase<Problem, Policy>;
using PipelineImplBase = GemmMicroscalePipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BDqDataType = remove_cvref_t<typename Problem::ADataType>;
static constexpr bool IsCastBeforeLDS = Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
using BLDSType = std::conditional_t<IsCastBeforeLDS, BDqDataType, BDataType>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BDqDataType = remove_cvref_t<typename Problem::ADataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
@@ -40,12 +46,16 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDqDataType>>::PackedSize;
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
static constexpr index_t BQPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BQDataType>>::PackedSize;
static constexpr index_t BLDSPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BLDSType>>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
@@ -82,6 +92,9 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
using Base::PrefetchStages;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
@@ -165,6 +178,11 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
{
using Base = PipelineImplBase;
static constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
static constexpr bool is_b_row_major =
std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
@@ -207,7 +225,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b =
B_LDS_Read_Width * sizeof(BDqDataType) / BPackedSize == 16
B_LDS_Read_Width * sizeof(BLDSType) / BLDSPackedSize == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
@@ -223,7 +241,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
constexpr auto ds_read_a_issue_cycle =
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
B_LDS_Read_Width * sizeof(BDqDataType) / BPackedSize == 16 ? 8 : 4;
B_LDS_Read_Width * sizeof(BLDSType) / BLDSPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
@@ -306,6 +324,197 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
});
}
template <typename TileType, typename CastTileType, typename ScaleTileType>
CK_TILE_DEVICE static void ScaleTile(const TileType& block_tile,
CastTileType& block_tile_cast,
const ScaleTileType& scale_tile)
{
if constexpr(IsCastBeforeLDS)
{
constexpr auto b_block = TileType::get_distributed_spans();
// Internally this is using V_CVT_SCALEF32_PK_BF16_FP4 or V_CVT_SCALEF32_PK_FP16_FP4
// on gfx950
auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) {
if constexpr(std::is_same_v<BDqDataType, half_t>)
{
return pk_fp4_to_fp16x2(pk_mxfp4, fscale);
}
else if constexpr(std::is_same_v<BDqDataType, bf16_t>)
{
return pk_fp4_to_bf16x2(pk_mxfp4, fscale);
}
else
{
static_assert(false, "unsupported compute type");
}
};
constexpr index_t BQuantGroupSizeIdx0 =
std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>
? BQuantGroupSize::kN
: BQuantGroupSize::kK;
constexpr index_t BQuantGroupSizeIdx1 =
std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>
? BQuantGroupSize::kK
: BQuantGroupSize::kN;
// The input indices are with respect to B block tile. If B and Bq have different
// layouts, the indices must be swapped
auto make_bq_index = [](auto idx0, auto idx1) {
if constexpr(std::is_same_v<BLayout, BQLayout>)
{
return make_tuple(
tile_distributed_index<idx0.impl_.at(0) / BQuantGroupSizeIdx0>{},
tile_distributed_index<idx1.impl_.at(0) / BQuantGroupSizeIdx1>{});
}
else
{
return make_tuple(
tile_distributed_index<idx1.impl_.at(0) / BQuantGroupSizeIdx0>{},
tile_distributed_index<idx0.impl_.at(0) / BQuantGroupSizeIdx1>{});
}
};
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) {
sweep_tile_span(b_block[number<1>{}], [&](auto idx1) {
if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
{
if constexpr(idx1.impl_.at(0) % BPackedSize == 0)
{
constexpr auto idx1_lo = tile_distributed_index<idx1.impl_.at(0)>{};
constexpr auto idx1_hi =
tile_distributed_index<idx1.impl_.at(0) + 1>{};
constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo);
constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi);
constexpr auto i_j_idx = make_tuple(idx0, idx1);
auto b_pack = block_tile[i_j_idx];
constexpr auto i_j_idx_scale_lo = make_bq_index(idx0, idx1_lo);
constexpr auto i_j_idx_scale_hi = make_bq_index(idx0, idx1_hi);
// If the scale is the same for packed values, use pk cvt scale
// instructions, otherwise scale and cast element by element
if constexpr(i_j_idx_scale_lo[I0{}].impl_.at(0) ==
i_j_idx_scale_hi[I0{}].impl_.at(0) &&
i_j_idx_scale_lo[I1{}].impl_.at(0) ==
i_j_idx_scale_hi[I1{}].impl_.at(0))
{
float scale = float(scale_tile[i_j_idx_scale_lo]);
auto cvt = pk_mxfp4_to_compute_v2(b_pack, scale);
block_tile_cast(i_j_idx_lo) = cvt.x;
block_tile_cast(i_j_idx_hi) = cvt.y;
}
else
{
float scale_lo = float(scale_tile[i_j_idx_scale_lo]);
auto b_f4_lo =
type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
block_tile_cast(i_j_idx_lo) = type_convert<BDqDataType>(
type_convert<float>(b_f4_lo) * scale_lo);
float scale_hi = float(scale_tile[i_j_idx_scale_hi]);
auto b_f4_hi =
type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
block_tile_cast(i_j_idx_hi) = type_convert<BDqDataType>(
type_convert<float>(b_f4_hi) * scale_hi);
}
}
}
else
{
constexpr auto i_j_idx = make_tuple(idx0, idx1);
constexpr auto i_j_idx_scale = make_bq_index(idx0, idx1);
float scale = float(scale_tile[i_j_idx_scale]);
auto b_pack = block_tile[i_j_idx];
block_tile_cast(i_j_idx) =
type_convert<BDqDataType>(type_convert<float>(b_pack) * scale);
}
});
});
}
}
template <typename WindowType, typename TileType, typename ElementwiseFunc>
CK_TILE_DEVICE void ALocalPrefill(WindowType& lds_window,
const TileType& block_tile,
const ElementwiseFunc& element_func) const
{
if constexpr(is_a_col_major && !is_a_load_tr_v())
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, block_tile);
Base::LocalPrefill(lds_window, a_shuffle_tmp, element_func);
}
else
{
Base::LocalPrefill(lds_window, block_tile, element_func);
}
}
template <typename WindowType,
typename TileType,
typename TileTypeCast,
typename ElementwiseFunc>
CK_TILE_DEVICE void BLocalPrefill(WindowType& lds_window,
const TileType& block_tile,
const TileTypeCast& block_tile_cast,
const ElementwiseFunc& element_func) const
{
// Fill LDS and apply the scale if IsCastBeforeLDS
auto get_b_block_tile = [](auto& b_block_tile_orig, auto& b_block_tile_cast) {
if constexpr(IsCastBeforeLDS)
{
return b_block_tile_cast;
}
else
{
return b_block_tile_orig;
}
};
if constexpr(is_b_row_major && !is_b_load_tr_v())
{
auto b_shuffle_tmp = make_static_distributed_tensor<BLDSType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, get_b_block_tile(block_tile, block_tile_cast));
Base::LocalPrefill(lds_window, b_shuffle_tmp, element_func);
}
else
{
Base::LocalPrefill(
lds_window, get_b_block_tile(block_tile, block_tile_cast), element_func);
}
}
template <typename BlockGemmType,
typename AWindowType,
typename BWindowType,
typename QTileType>
CK_TILE_DEVICE void LocalPrefetch(BlockGemmType& block_gemm,
const AWindowType& a_lds_window,
const BWindowType& b_lds_window,
const QTileType& q_block_tile) const
{
// Load from LDS
// It can apply the scale and cast if we scale after reading from LDS
if constexpr(IsCastBeforeLDS)
{
block_gemm.LocalPrefetch(
a_lds_window, b_lds_window, is_a_load_tr_v, is_b_load_tr_v);
}
else
{
block_gemm.LocalPrefetch(
a_lds_window, b_lds_window, q_block_tile, is_a_load_tr_v, is_b_load_tr_v);
}
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
@@ -321,6 +530,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
index_t num_loop,
void* p_smem) const
{
// -----------------------------------------------------------------------------------------
// Pipeline checks
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
@@ -330,15 +541,14 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
"A/B/BQ Dram block window should have the same data type as appropriate "
"([A|B|BQ]DataType) defined in Problem definition!");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_bq_col_major =
std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
static_assert(is_bq_col_major
? (NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"Bq block window has incorrect lengths for defined BqLayout!");
static_assert(is_a_col_major
@@ -347,13 +557,12 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(
is_b_row_major
? (KPerBlock / 2 == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock / 2 == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
@@ -380,20 +589,19 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
// B scale DRAM tile window for load
// auto b_scale_copy_dram_window =
// make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
// bq_dram_block_window_tmp.get_window_lengths(),
// bq_dram_block_window_tmp.get_window_origin(),
// Policy::template GetBQDramLoadWindow<Problem>());
auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp);
auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){};
// This defines the scaled and casted block tile for B matrix.
// Effectively, it is used only if we scale and cast before writing to LDS.
auto bdq_block_tile = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeBRegTileDistribution<Problem>());
// Block GEMM
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
// using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile =
@@ -402,114 +610,61 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_block_tile;
BBlockTile b_fp4_block_tile;
BBlockTile b_block_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock / 2, 0) : make_array(0, KPerBlock / 2);
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr index_t b_scale_dram_tile_window_step = KPerBlock / BQuantGroupSize::kK;
constexpr BQDramTileWindowStep b_scale_dram_tile_window_step =
std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>
? make_array(0, KPerBlock / BQuantGroupSize::kK)
: make_array(KPerBlock / BQuantGroupSize::kK, 0);
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
// auto a_scale_block_tile = decltype(load_tile(a_scale_copy_dram_window)){};
// prefetch stages
// Vmem -> Vgpr 0
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// BDataType
auto b_block_tile = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeBRegTileDistribution<Problem>());
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// Vmem -> Vgpr 0 (Q matrix)
// Scale and cast tile before writing to LDS (if IsCastBeforeLDS)
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
move_tile_window(bq_copy_dram_window, b_scale_dram_tile_window_step);
ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile);
constexpr auto idx1_js = tile_distributed_index<0>{};
constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans();
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) {
sweep_tile_span(b_block[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js);
auto b_scale_uint = type_convert<int32_t>(bq_block_tile(i_j_idx_scale)) - 127;
auto b_scale = type_convert<float>(std::pow(2.0f, b_scale_uint));
constexpr auto idx1_lo = tile_distributed_index<idx1.impl_.at(0) * 2>{};
constexpr auto idx1_hi = tile_distributed_index<idx1.impl_.at(0) * 2 + 1>{};
constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo);
constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi);
auto b_pack = type_convert<pk_fp4_t>(b_fp4_block_tile(i_j_idx));
auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
b_block_tile(i_j_idx_lo) =
type_convert<bf16_t>(type_convert<float>(b_f4_lo) * b_scale);
b_block_tile(i_j_idx_hi) =
type_convert<bf16_t>(type_convert<float>(b_f4_hi) * b_scale);
});
});
// initialize C
// initialize C tile to zero
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
block_sync_lds();
// LDS write 0
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
// Vgpr -> LDS 0
ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func);
// Vmem -> Vgpr 1
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) {
sweep_tile_span(b_block[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js);
auto b_scale_uint = type_convert<int32_t>(bq_block_tile(i_j_idx_scale)) - 127;
auto b_scale = type_convert<float>(std::pow(2.0f, b_scale_uint));
constexpr auto idx1_lo = tile_distributed_index<idx1.impl_.at(0) * 2>{};
constexpr auto idx1_hi = tile_distributed_index<idx1.impl_.at(0) * 2 + 1>{};
constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo);
constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi);
auto b_pack = type_convert<pk_fp4_t>(b_fp4_block_tile(i_j_idx));
auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
b_block_tile(i_j_idx_lo) =
type_convert<bf16_t>(type_convert<float>(b_f4_lo) * b_scale);
b_block_tile(i_j_idx_hi) =
type_convert<bf16_t>(type_convert<float>(b_f4_hi) * b_scale);
});
});
// If we scale and cast before writing to LDS,
// we need to read another tile of Q matrix from Vmem, then scale and cast tile
if constexpr(IsCastBeforeLDS)
{
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, b_scale_dram_tile_window_step);
}
ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
// LDS -> Vgpr 0
LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
__builtin_amdgcn_sched_barrier(0);
@@ -521,72 +676,34 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
{
block_sync_lds();
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
// Vgpr -> LDS
ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func);
// Vmem -> Vgpr
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// Vmem -> Vgpr (Q matrix)
// Scale and cast tile before writing to LDS (if IsCastBeforeLDS)
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) {
sweep_tile_span(b_block[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js);
auto b_scale_uint =
type_convert<int32_t>(bq_block_tile(i_j_idx_scale)) - 127;
auto b_scale = type_convert<float>(std::pow(2.0f, b_scale_uint));
constexpr auto idx1_lo = tile_distributed_index<idx1.impl_.at(0) * 2>{};
constexpr auto idx1_hi =
tile_distributed_index<idx1.impl_.at(0) * 2 + 1>{};
constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo);
constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi);
auto b_pack = type_convert<pk_fp4_t>(b_fp4_block_tile(i_j_idx));
auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
b_block_tile(i_j_idx_lo) =
type_convert<bf16_t>(type_convert<float>(b_f4_lo) * b_scale);
b_block_tile(i_j_idx_hi) =
type_convert<bf16_t>(type_convert<float>(b_f4_hi) * b_scale);
});
});
move_tile_window(bq_copy_dram_window, b_scale_dram_tile_window_step);
ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile);
// Consume tile
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
// LDS -> Vgpr
LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 1;
// b_block_stride +=1;
} while(i < (num_loop - 1));
}
// tile_elementwise_inout([](auto& c) { c = 0; }, acc_block_tile);
// tail
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
{
@@ -596,35 +713,31 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
}
else
{
// If we scale and cast after reading from LDS,
// we didn't read the second tile of Q matrix from Vmem during prefetch stages,
// so we need to read the last tile here.
// This is not a problem because we have all block_gemm instructions to hide the
// latency.
if constexpr(!IsCastBeforeLDS)
{
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, b_scale_dram_tile_window_step);
}
// Consume second to last tile
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
// Vgpr -> LDS last tile
ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
// LDS -> Vgpr last tile
LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile);
// Consume last tile
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
}
@@ -653,9 +766,9 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Probl
ck_tile::ignore = n;
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
identity{},
b_dram_block_window_tmp,
[](const BDqDataType& b) { return b; },
identity{},
bq_dram_block_window_tmp,
num_loop,
p_smem);

View File

@@ -1,140 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "gemm_group_quant_utils.hpp"
namespace ck_tile {
struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
{
using Base = UniversalGemmPipelineAgBgCrPolicy;
using Base::I0;
using Base::I1;
using Base::I2;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ()
{
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBRegTileDistribution()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// Tile: KPerBlock X NPerBlock
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
getBTileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
// Tile: NPerBlock X KPerBlock
else
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
NPerBlock,
KPerBlock,
VecLoadSize,
getBTileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution()
{
// using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK; // k_scale num //2
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
constexpr index_t warp_size = get_warp_size();
constexpr index_t num_warps = BlockSize / get_warp_size();
constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size);
constexpr index_t b_vec = VecLoadSize > LargestVec ? LargestVec : VecLoadSize;
constexpr index_t K0 = KPerBlock / b_vec;
constexpr index_t K1 = K0 / KScale;
constexpr index_t K3 = K0 / K1;
constexpr index_t K2 = 1;
constexpr index_t N0 = num_warps / NumWaveGroups;
constexpr index_t N1 = warp_size / K0;
constexpr index_t N2 = NPerBlock / (N0 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<K1>,
tuple<sequence<N0, N1, N2>, sequence<K3, K2>>,
tuple<sequence<1>, sequence<1, 2, 0>>,
tuple<sequence<0>, sequence<1, 0, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
static_assert(Problem::BQuantGroupSize::kK % WarpTile::at(I2) == 0,
"KPerWarpGemm must be a multiple of QuantGroupSize!");
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
std::is_same_v<typename Problem::ComputeDataType, bf8_t> ||
std::is_same_v<typename Problem::ComputeDataType, bf16_t>);
static_assert(std::is_same_v<typename Problem::CDataType, float>);
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<
typename Problem::ADataType,
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile

View File

@@ -24,7 +24,8 @@ template <typename ADataType_,
typename ComputeDataType_ = void,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
TailNumber TailNum_ = TailNumber::Full,
CastPolicy BCastPolicy_ = CastPolicy::AfterLDSRead>
struct GemmQuantPipelineProblemBase
: public GemmPipelineProblemBase<
ADataType_,
@@ -82,6 +83,20 @@ struct GemmQuantPipelineProblemBase
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
// gfx950 supports load with transpose for 4bit types, so we can transpose
// pk_fp4_t from LDS in registers. But without this instruction,
// the transpose is done in register between Vmem read and LDS write and
// the implementation does not support 4 bit types
#ifdef __gfx950__
static constexpr auto BCastPolicy = BCastPolicy_;
#else
static constexpr auto BCastPolicy =
std::is_same_v<BDataType, pk_fp4_t> &&
std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>
? CastPolicy::BeforeLDSWrite
: BCastPolicy_;
#endif
static_assert(BlockGemmShape::kM % AQuantGroupSize::kM == 0);
static_assert(BlockGemmShape::kK % AQuantGroupSize::kK == 0);
static_assert(BlockGemmShape::kM % BQuantGroupSize::kM == 0);
@@ -155,7 +170,8 @@ template <typename ADataType_,
typename ComputeDataType_ = ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
TailNumber TailNum_ = TailNumber::Full,
CastPolicy BCastPolicy_ = CastPolicy::AfterLDSRead>
using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
void, // no AQDataType for BQuant
BDataType_,
@@ -169,7 +185,8 @@ using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase<ADataType_,
ComputeDataType_,
Scheduler_,
HasHotLoop_,
TailNum_>;
TailNum_,
BCastPolicy_>;
template <typename ADataType_,
typename AQDataType_,