mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#5323 (commit 5454e9e)
CK Tile MX GEMM Packing Improvement ## Motivation Reduce the size of the scale loads and make better use of MFMA scale selection. ## Technical Details Add packing of the MX scales: multiple e8m0 scale values are packed into one int32_t on the host. ## Test Plan Use the existing test cases. ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
859acb5ae7
commit
5f90f69795
@@ -272,7 +272,10 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
|
||||
if(!preshuffle && GemmConfig::UseStructuredSparsity)
|
||||
{
|
||||
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
|
||||
if constexpr(GemmConfig::UseStructuredSparsity)
|
||||
{
|
||||
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
|
||||
@@ -14,7 +14,56 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const float max_accumulated_v
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
// Use e8m0_t directly without packing - simpler and cleaner approach
|
||||
// Pack [MN, K/32] e8m0_t scales into [MN/MNPack, K/32/KPack] int32_t
|
||||
// Each int32_t contains MNPack * KPack e8m0_t values with byte layout matching
|
||||
// the GPU tile distribution: values are XdlMNThread apart in M and XdlKThread apart in K.
|
||||
// byte[ik * MNPack + imn] = e8m0 at strided (mn, k) position
|
||||
// kLast=true for A scales (layout [M, K/32]), kLast=false for B scales (layout [K/32, N])
|
||||
template <ck_tile::index_t MNPack = 2,
|
||||
ck_tile::index_t KPack = 2,
|
||||
ck_tile::index_t XdlMNThread = 16,
|
||||
ck_tile::index_t XdlKThread = 4>
|
||||
auto packScalesMNxK(const ck_tile::HostTensor<ck_tile::e8m0_t>& src, bool kLast)
|
||||
{
|
||||
auto src_lengths = src.get_lengths();
|
||||
const ck_tile::index_t MN = kLast ? src_lengths[0] : src_lengths[1];
|
||||
const ck_tile::index_t K_scale = kLast ? src_lengths[1] : src_lengths[0];
|
||||
const ck_tile::index_t MN_packed = MN / MNPack;
|
||||
const ck_tile::index_t K_packed = K_scale / KPack;
|
||||
const ck_tile::index_t total_packed = MN_packed * K_packed;
|
||||
|
||||
// Output as flat vector of int32_t (row-major [MN/MNPack, K/32/KPack])
|
||||
std::vector<int32_t> packed(total_packed);
|
||||
|
||||
for(ck_tile::index_t packed_mn = 0; packed_mn < MN_packed; packed_mn++)
|
||||
{
|
||||
for(ck_tile::index_t packed_k = 0; packed_k < K_packed; packed_k++)
|
||||
{
|
||||
int32_t val = 0;
|
||||
ck_tile::index_t mn_lane = packed_mn % XdlMNThread;
|
||||
ck_tile::index_t mn_group = packed_mn / XdlMNThread;
|
||||
ck_tile::index_t k_lane = packed_k % XdlKThread;
|
||||
ck_tile::index_t k_group = packed_k / XdlKThread;
|
||||
for(ck_tile::index_t ik = 0; ik < KPack; ik++)
|
||||
{
|
||||
for(ck_tile::index_t imn = 0; imn < MNPack; imn++)
|
||||
{
|
||||
ck_tile::index_t byteIdx = ik * MNPack + imn;
|
||||
ck_tile::index_t orig_mn =
|
||||
mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane;
|
||||
ck_tile::index_t orig_k =
|
||||
k_group * XdlKThread * KPack + ik * XdlKThread + k_lane;
|
||||
|
||||
ck_tile::e8m0_t v = kLast ? src(orig_mn, orig_k) : src(orig_k, orig_mn);
|
||||
val |= (static_cast<int32_t>(v.get()) << (byteIdx * 8));
|
||||
}
|
||||
}
|
||||
packed[packed_mn * K_packed + packed_k] = val;
|
||||
}
|
||||
}
|
||||
return packed;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -101,21 +150,43 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout)
|
||||
break;
|
||||
}
|
||||
|
||||
// Device buffers for A, B, C, and scale tensors
|
||||
// Compute effective XdlPack sizes based on GemmConfig tile dimensions
|
||||
constexpr ck_tile::index_t MPerXdl_ = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t NPerXdl_ = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t KPerXdl_ = GemmConfig::K_Warp_Tile;
|
||||
constexpr ck_tile::index_t MIterPerWarp_ = GemmConfig::M_Tile / (GemmConfig::M_Warp * MPerXdl_);
|
||||
constexpr ck_tile::index_t NIterPerWarp_ = GemmConfig::N_Tile / (GemmConfig::N_Warp * NPerXdl_);
|
||||
constexpr ck_tile::index_t KIterPerWarp_ = GemmConfig::K_Tile / KPerXdl_;
|
||||
|
||||
constexpr ck_tile::index_t MXdlPackEff = (MIterPerWarp_ >= 2 && MIterPerWarp_ % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t NXdlPackEff = (NIterPerWarp_ >= 2 && NIterPerWarp_ % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t KXdlPackEff = (KIterPerWarp_ >= 2 && KIterPerWarp_ % 2 == 0) ? 2 : 1;
|
||||
|
||||
// Pack scales: [M, K/32] e8m0_t → [M/MXdlPackEff, K/32/KXdlPackEff] int32_t
|
||||
// Original unpacked tensors are kept for CPU reference validation
|
||||
constexpr ck_tile::index_t XdlMNThread = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
auto scale_a_packed =
|
||||
packScalesMNxK<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_a_host, true);
|
||||
auto scale_b_packed =
|
||||
packScalesMNxK<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_b_host, false);
|
||||
|
||||
// Device buffers for A, B, C, and packed scale tensors
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_dev_buf(b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_packed.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_packed.size() * sizeof(int32_t));
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
b_dev_buf.ToDevice(b_host.data());
|
||||
c_dev_buf.SetZero(); // Initialize C buffer to zero
|
||||
scale_a_dev_buf.ToDevice(scale_a_host.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_host.data());
|
||||
c_dev_buf.SetZero();
|
||||
scale_a_dev_buf.ToDevice(scale_a_packed.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_packed.data());
|
||||
|
||||
// Scale pointers - use e8m0_t* directly
|
||||
using ScaleM = ck_tile::MXScalePointer<ScaleType, 1, 32>; // in blocks of 32 in K
|
||||
// Scale pointers - point to packed int32_t data, kernel reinterprets as int32_t*
|
||||
using ScaleM = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
using ScaleN = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
ScaleM scale_m(reinterpret_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()));
|
||||
ScaleN scale_n(reinterpret_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()));
|
||||
|
||||
@@ -249,14 +249,19 @@ struct BlockGemmARegBRegCRegV1
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B with MX scaling
|
||||
// ScaleATensor: [MIterPerWarp, KIterPerWarp] -> int32_t
|
||||
// ScaleBTensor: [NIterPerWarp, KIterPerWarp] -> int32_t
|
||||
// C += A * B with MX scaling and packed-in-two (XdlPack) optimization
|
||||
// Scale tensors contain pre-packed int32_t: each int32_t holds MXdlPack * KXdlPack e8m0_t
|
||||
// values (for A) or NXdlPack * KXdlPack (for B), packed on the host.
|
||||
// Uses OpSel (0-3) to select which byte within the packed int32_t for each MFMA call.
|
||||
// XdlPack template parameters default to 2; fall back to 1 when iteration count is too small.
|
||||
template <typename CBlockTensor,
|
||||
typename ABlockTensor,
|
||||
typename BBlockTensor,
|
||||
typename ScaleATensor,
|
||||
typename ScaleBTensor>
|
||||
typename ScaleBTensor,
|
||||
index_t MXdlPack_ = 2,
|
||||
index_t NXdlPack_ = 2,
|
||||
index_t KXdlPack_ = 2>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensor& a_block_tensor,
|
||||
const BBlockTensor& b_block_tensor,
|
||||
@@ -304,53 +309,88 @@ struct BlockGemmARegBRegCRegV1
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop with MX scaling:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
// Effective XdlPack: fall back to 1 when iteration count is insufficient
|
||||
constexpr index_t MXdlPack =
|
||||
(MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1;
|
||||
constexpr index_t NXdlPack =
|
||||
(NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1;
|
||||
constexpr index_t KXdlPack =
|
||||
(KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1;
|
||||
|
||||
// get A scale for this M-K tile using get_y_sliced_thread_data
|
||||
constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
|
||||
constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
|
||||
constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
|
||||
|
||||
// hot loop with MX scaling and pre-packed int32_t scales:
|
||||
// Outer loops iterate over pack groups (scale tile indices)
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
|
||||
// Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t)
|
||||
auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data(
|
||||
sequence<kIter, mIter, 0>{}, sequence<1, 1, 1>{});
|
||||
const auto a_scale_e8m0 = scale_a_slice[number<0>{}];
|
||||
const int32_t a_scale = static_cast<int32_t>(a_scale_e8m0.get());
|
||||
sequence<ikpack, impack, 0>{}, sequence<1, 1, 1>{});
|
||||
const int32_t a_scale_packed = bit_cast<int32_t>(scale_a_slice[number<0>{}]);
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// get B scale for this N-K tile using get_y_sliced_thread_data
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
|
||||
// Get pre-packed int32_t B scale
|
||||
auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
|
||||
sequence<kIter, nIter, 0>{}, sequence<1, 1, 1>{});
|
||||
const auto b_scale_e8m0 = scale_b_slice[number<0>{}];
|
||||
const int32_t b_scale = static_cast<int32_t>(b_scale_e8m0.get());
|
||||
sequence<ikpack, inpack, 0>{}, sequence<1, 1, 1>{});
|
||||
const int32_t b_scale_packed = bit_cast<int32_t>(scale_b_slice[number<0>{}]);
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
using c_iter_idx = std::
|
||||
conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
// Inner loops: issue MFMAs within the pack group using OpSel
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto kIter = ikpack * KXdlPack + ikxdl;
|
||||
constexpr auto mIter = impack * MXdlPack + imxdl;
|
||||
|
||||
// warp GEMM with MX scaling
|
||||
// Cast e8m0_t to int32_t, use OpSel=0 (least significant byte)
|
||||
constexpr index_t kOpSel = 0; // Always use OpSel=0
|
||||
WarpGemm{}.template operator()<kOpSel, kOpSel>(
|
||||
c_warp_tensor, a_warp_tensor, b_warp_tensor, a_scale, b_scale);
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() =
|
||||
a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
// OpSel for A: selects byte within packed int32_t
|
||||
constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl;
|
||||
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
constexpr auto nIter = inpack * NXdlPack + inxdl;
|
||||
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() =
|
||||
b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{},
|
||||
b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// OpSel for B: selects byte within packed int32_t
|
||||
constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl;
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
using c_iter_idx = std::conditional_t<TransposeC,
|
||||
sequence<nIter, mIter>,
|
||||
sequence<mIter, nIter>>;
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM with MX scaling using pre-packed scale and OpSel
|
||||
WarpGemm{}.template operator()<kOpSelA, kOpSelB>(c_warp_tensor,
|
||||
a_warp_tensor,
|
||||
b_warp_tensor,
|
||||
a_scale_packed,
|
||||
b_scale_packed);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -118,7 +118,12 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
// We are not storing the original packed type in LDS, so we need to multiply the smem size
|
||||
// by the packed size.
|
||||
constexpr index_t smem_size_a = Policy::template GetSmemSizeA<Problem>() * APackedSize;
|
||||
constexpr index_t smem_size_b = Policy::template GetSmemSizeB<Problem>() * BPackedSize;
|
||||
|
||||
return 2 * (smem_size_a + smem_size_b);
|
||||
}
|
||||
|
||||
static constexpr index_t MFMA_INST = MIterPerWarp * NIterPerWarp * KIterPerWarp;
|
||||
|
||||
@@ -98,6 +98,30 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
static constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
// XdlPack: desired packing of e8m0_t scale values into int32_t
|
||||
static constexpr index_t MXdlPack = 2;
|
||||
static constexpr index_t NXdlPack = 2;
|
||||
static constexpr index_t KXdlPack = 2;
|
||||
|
||||
// Effective pack sizes: fall back to 1 when dimension is too small
|
||||
using BlockWarps_ = typename BlockGemmShape::BlockWarps;
|
||||
static constexpr index_t MPerBlock_ = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock_ = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock_ = BlockGemmShape::kK;
|
||||
static constexpr index_t MWarp_ = BlockWarps_::at(number<0>{});
|
||||
static constexpr index_t NWarp_ = BlockWarps_::at(number<1>{});
|
||||
static constexpr index_t KPerXdl_ = BlockGemmShape::WarpTile::at(number<2>{});
|
||||
static constexpr index_t MIterPerWarp_ = MPerBlock_ / (MWarp_ * MThreadPerXdl);
|
||||
static constexpr index_t NIterPerWarp_ = NPerBlock_ / (NWarp_ * NThreadPerXdl);
|
||||
static constexpr index_t KIterPerWarp_ = KPerBlock_ / KPerXdl_;
|
||||
|
||||
static constexpr index_t MXdlPackEff =
|
||||
(MIterPerWarp_ >= MXdlPack && MIterPerWarp_ % MXdlPack == 0) ? MXdlPack : 1;
|
||||
static constexpr index_t NXdlPackEff =
|
||||
(NIterPerWarp_ >= NXdlPack && NIterPerWarp_ % NXdlPack == 0) ? NXdlPack : 1;
|
||||
static constexpr index_t KXdlPackEff =
|
||||
(KIterPerWarp_ >= KXdlPack && KIterPerWarp_ % KXdlPack == 0) ? KXdlPack : 1;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
@@ -245,7 +269,9 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
return c_block_window;
|
||||
}
|
||||
|
||||
// Create scale A block windows following the pattern of MakeABlockWindows
|
||||
// Create scale A block windows with packed int32_t layout
|
||||
// Host packs 2M x 2K e8m0_t values into one int32_t
|
||||
// Tensor view: [M/MXdlPack, K/32/KXdlPack] of int32_t
|
||||
template <typename ScaleM, typename ScaleN>
|
||||
CK_TILE_DEVICE static auto MakeScaleABlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs,
|
||||
const index_t i_m)
|
||||
@@ -253,28 +279,28 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
auto scale_a = kargs.scale_m_ptr;
|
||||
|
||||
static constexpr int BlockScaleSize = ScaleM::GranularityK;
|
||||
const auto scale_k_size = kargs.K / BlockScaleSize;
|
||||
const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPackEff;
|
||||
const auto scale_m_packed = kargs.M / MXdlPackEff;
|
||||
|
||||
// A scale tensor view - layout [M, scale_k_size] with e8m0_t elements
|
||||
// Use e8m0_t directly without packing
|
||||
// A scale tensor view - layout [M/MXdlPackEff, K/32/KXdlPackEff] with int32_t elements
|
||||
const auto scale_a_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const e8m0_t*>(scale_a.ptr),
|
||||
make_tuple(kargs.M, scale_k_size),
|
||||
make_tuple(scale_k_size, 1));
|
||||
reinterpret_cast<const int32_t*>(scale_a.ptr),
|
||||
make_tuple(scale_m_packed, scale_k_packed),
|
||||
make_tuple(scale_k_packed, 1));
|
||||
|
||||
// Create block window for scale A
|
||||
// K dimension: scale_k_size e8m0_t elements
|
||||
// i_m is element offset (iM * MPerBlock), not tile index
|
||||
auto scale_a_block_window =
|
||||
make_tile_window(scale_a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / BlockScaleSize>{}),
|
||||
{i_m, 0});
|
||||
// Tile window shape: [MPerBlock/MXdlPackEff, KPerBlock/32/KXdlPackEff]
|
||||
auto scale_a_block_window = make_tile_window(
|
||||
scale_a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock / MXdlPackEff>{},
|
||||
number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPackEff>{}),
|
||||
{i_m / MXdlPackEff, 0});
|
||||
|
||||
return scale_a_block_window;
|
||||
}
|
||||
|
||||
// Create scale B block windows following the pattern of MakeBBlockWindows
|
||||
// Create scale B block windows with packed int32_t layout
|
||||
// Host packs 2N x 2K e8m0_t values into one int32_t
|
||||
// Tensor view: [N/NXdlPack, K/32/KXdlPack] of int32_t
|
||||
template <typename ScaleM, typename ScaleN>
|
||||
CK_TILE_DEVICE static auto MakeScaleBBlockWindows(const KernelArgs<ScaleM, ScaleN>& kargs,
|
||||
const index_t i_n)
|
||||
@@ -282,23 +308,21 @@ struct MXGemmKernel : UniversalGemmKernel<TilePartitioner_, MXGemmPipeline_, Epi
|
||||
auto scale_b = kargs.scale_n_ptr;
|
||||
|
||||
static constexpr int BlockScaleSize = ScaleN::GranularityK;
|
||||
const auto scale_k_size = kargs.K / BlockScaleSize;
|
||||
const auto scale_k_packed = kargs.K / BlockScaleSize / KXdlPackEff;
|
||||
const auto scale_n_packed = kargs.N / NXdlPackEff;
|
||||
|
||||
// B scale tensor view
|
||||
// Host stores as [K/32, N] col-major = [N, K/32] row-major from access perspective
|
||||
// B scale tensor view - [N/NXdlPackEff, K/32/KXdlPackEff] of int32_t
|
||||
const auto scale_b_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const e8m0_t*>(scale_b.ptr),
|
||||
make_tuple(kargs.N, scale_k_size), // [N, K/32] for access
|
||||
make_tuple(scale_k_size, 1)); // stride to match col-major storage
|
||||
reinterpret_cast<const int32_t*>(scale_b.ptr),
|
||||
make_tuple(scale_n_packed, scale_k_packed),
|
||||
make_tuple(scale_k_packed, 1));
|
||||
|
||||
// Create block window for scale B
|
||||
// Tile window shape matches access pattern: [NPerBlock, KPerBlock/32]
|
||||
// i_n is element offset (iN * NPerBlock)
|
||||
auto scale_b_block_window =
|
||||
make_tile_window(scale_b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / BlockScaleSize>{}),
|
||||
{i_n, 0});
|
||||
// Tile window shape: [NPerBlock/NXdlPackEff, KPerBlock/32/KXdlPackEff]
|
||||
auto scale_b_block_window = make_tile_window(
|
||||
scale_b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock / NXdlPackEff>{},
|
||||
number<TilePartitioner::KPerBlock / BlockScaleSize / KXdlPackEff>{}),
|
||||
{i_n / NXdlPackEff, 0});
|
||||
|
||||
return scale_b_block_window;
|
||||
}
|
||||
|
||||
@@ -315,14 +315,36 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
},
|
||||
number<BsLayout::size()>{});
|
||||
|
||||
////////////// MX Scale windows /////////////////
|
||||
////////////// MX Scale windows (pre-packed int32_t) /////////////////
|
||||
// Get WarpGemm configuration
|
||||
using BlockWarps = typename BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
constexpr index_t MWarp = BlockWarps::at(I0{});
|
||||
constexpr index_t NWarp = BlockWarps::at(I1{});
|
||||
|
||||
// Calculate scale dimensions: KPerBlock elements need KPerBlock/32 e8m0_t scales
|
||||
constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize;
|
||||
// Compute effective XdlPack sizes (fall back to 1 when iter count < pack)
|
||||
constexpr index_t MPerXdl = WarpTile::at(I0{});
|
||||
constexpr index_t NPerXdl = WarpTile::at(I1{});
|
||||
constexpr index_t KPerXdl = WarpTile::at(I2{});
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
|
||||
|
||||
constexpr index_t MXdlPackEff =
|
||||
(MIterPerWarp >= Policy::MXdlPack && MIterPerWarp % Policy::MXdlPack == 0)
|
||||
? Policy::MXdlPack
|
||||
: 1;
|
||||
constexpr index_t NXdlPackEff =
|
||||
(NIterPerWarp >= Policy::NXdlPack && NIterPerWarp % Policy::NXdlPack == 0)
|
||||
? Policy::NXdlPack
|
||||
: 1;
|
||||
constexpr index_t KXdlPackEff =
|
||||
(KIterPerWarp >= Policy::KXdlPack && KIterPerWarp % Policy::KXdlPack == 0)
|
||||
? Policy::KXdlPack
|
||||
: 1;
|
||||
|
||||
// Packed scale dimensions
|
||||
constexpr index_t ScaleKDimPerBlock = KPerBlock / ScaleBlockSize / KXdlPackEff;
|
||||
|
||||
// Scale tensor views and base origins for creating tile windows per iteration
|
||||
const auto& scale_a_tensor_view = scale_a_window.get_bottom_tensor_view();
|
||||
@@ -330,18 +352,18 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
auto scale_a_base_origin = scale_a_window.get_window_origin();
|
||||
auto scale_b_base_origin = scale_b_window.get_window_origin();
|
||||
|
||||
// Create sample scale windows to determine tile types
|
||||
auto scale_a_dram_window =
|
||||
make_tile_window(scale_a_tensor_view,
|
||||
make_tuple(number<MPerBlock>{}, number<ScaleKDimPerBlock>{}),
|
||||
scale_a_base_origin,
|
||||
Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
|
||||
// Create scale windows with packed int32_t dimensions
|
||||
auto scale_a_dram_window = make_tile_window(
|
||||
scale_a_tensor_view,
|
||||
make_tuple(number<MPerBlock / MXdlPackEff>{}, number<ScaleKDimPerBlock>{}),
|
||||
scale_a_base_origin,
|
||||
Policy::template MakeMX_ScaleA_DramTileDistribution<Problem>());
|
||||
|
||||
auto scale_b_dram_window =
|
||||
make_tile_window(scale_b_tensor_view,
|
||||
make_tuple(number<NPerBlock>{}, number<ScaleKDimPerBlock>{}),
|
||||
scale_b_base_origin,
|
||||
Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
|
||||
auto scale_b_dram_window = make_tile_window(
|
||||
scale_b_tensor_view,
|
||||
make_tuple(number<NPerBlock / NXdlPackEff>{}, number<ScaleKDimPerBlock>{}),
|
||||
scale_b_base_origin,
|
||||
Policy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
|
||||
|
||||
// this pipeline has a pair of LDS buffers per logical tile
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
|
||||
@@ -427,8 +449,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
"SmemSizeB size is wrong!");
|
||||
|
||||
////////////// MX Scale register tiles (ping-pong buffers) /////////////////
|
||||
// No packing needed - each thread gets e8m0_t elements directly
|
||||
// Each thread will cast e8m0_t to int32_t for WarpGemm with OpSel=0
|
||||
// Scales are pre-packed int32_t: each int32_t holds 2M/N x 2K e8m0_t values
|
||||
// Block GEMM uses OpSel (0-3) to select the right byte per MFMA call
|
||||
|
||||
using ScaleATileType = decltype(load_tile(scale_a_dram_window));
|
||||
using ScaleBTileType = decltype(load_tile(scale_b_dram_window));
|
||||
|
||||
@@ -131,7 +131,15 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// MX Scale tile distributions for loading from global memory
|
||||
// XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension
|
||||
// Host packs MXdlPack * KXdlPack e8m0_t into one int32_t for A scales
|
||||
// Host packs NXdlPack * KXdlPack e8m0_t into one int32_t for B scales
|
||||
static constexpr int MXdlPack = 2;
|
||||
static constexpr int NXdlPack = 2;
|
||||
static constexpr int KXdlPack = 2;
|
||||
|
||||
// MX Scale tile distributions for loading pre-packed int32_t from global memory
|
||||
// Packed layout: [M/MXdlPack, K/32/KXdlPack] of int32_t
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
|
||||
{
|
||||
@@ -145,21 +153,29 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
constexpr index_t MPerXdl = WarpTile::at(number<0>{});
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K_Lane = get_warp_size() / MPerXdl; // 64/16 = 4 threads in K dimension
|
||||
constexpr index_t K_Lane = get_warp_size() / MPerXdl;
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * MPerXdl);
|
||||
constexpr index_t KPerXdl = WarpTile::at(number<2>{});
|
||||
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
|
||||
constexpr index_t KPerLane = KPerXdl / BlockScaleSize / K_Lane;
|
||||
|
||||
// Effective pack sizes: fall back to 1 when iteration count < pack size
|
||||
constexpr index_t MXdlPackEff =
|
||||
(MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1;
|
||||
constexpr index_t KXdlPackEff =
|
||||
(KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1;
|
||||
|
||||
constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff;
|
||||
constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<NWarp>, // repeat over MWarps
|
||||
tuple<sequence<MIterPerWarp, MWarp, MPerXdl>, // M dimension (first)
|
||||
sequence<KIterPerWarp, K_Lane, KPerLane>>, // K dimension (second)
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, MPerXdl>
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>, // <KIterPerWarp, MIterPerWarp, KPerLane>
|
||||
sequence<0, 0, 2>>{});
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp_packed, MWarp, MPerXdl>,
|
||||
sequence<KIterPerWarp_packed, K_Lane, KPerLane>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -169,27 +185,35 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
using BlockWarps = typename BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t MWarp = BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = BlockWarps::at(number<1>{});
|
||||
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t K_Lane = get_warp_size() / NPerXdl; // 64/16 = 4 threads in K dimension
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t MWarp = BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = BlockWarps::at(number<1>{});
|
||||
constexpr index_t NPerXdl = WarpTile::at(number<1>{});
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t K_Lane = get_warp_size() / NPerXdl;
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * NPerXdl);
|
||||
|
||||
constexpr index_t KPerXdl = WarpTile::at(number<2>{});
|
||||
constexpr index_t KIterPerWarp = KPerBlock / KPerXdl;
|
||||
constexpr index_t KPerLane = KPerXdl / BlockScaleSize / K_Lane;
|
||||
|
||||
// Effective pack sizes: fall back to 1 when iteration count < pack size
|
||||
constexpr index_t NXdlPackEff =
|
||||
(NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? NXdlPack : 1;
|
||||
constexpr index_t KXdlPackEff =
|
||||
(KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? KXdlPack : 1;
|
||||
|
||||
constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff;
|
||||
constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<MWarp>, // repeat over MWarps
|
||||
tuple<sequence<NIterPerWarp, NWarp, NPerXdl>, // N dimension (first)
|
||||
sequence<KIterPerWarp, K_Lane, KPerLane>>, // K dimension (second)
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>, // <MWarp, NWarp>, <K_Lane, MPerXdl>
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>, // <KIterPerWarp, NIterPerWarp, KPerLane>
|
||||
sequence<0, 0, 2>>{});
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp_packed, NWarp, NPerXdl>,
|
||||
sequence<KIterPerWarp_packed, K_Lane, KPerLane>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -134,7 +134,12 @@ struct ABQuantGemmPipelineAgBgCrEightWaves : public BaseGemmPipelineAgBgCrCompV3
|
||||
|
||||
/// Returns the shared-memory (LDS) size requested by this pipeline.
///
/// The A and B sizes reported by the policy are each multiplied by the
/// corresponding packed size, because the original packed type is not what is
/// stored in LDS. The stale `return Policy::template GetSmemSize<Problem>();`
/// that preceded this computation has been removed: it returned early and made
/// the packed-size scaling unreachable.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    // We are not storing the original packed type in LDS, so we need to multiply
    // the smem size by the packed size.
    constexpr index_t smem_size_a = Policy::template GetSmemSizeA<Problem>() * APackedSize;
    constexpr index_t smem_size_b = Policy::template GetSmemSizeB<Problem>() * BPackedSize;

    // NOTE(review): the factor of 2 presumably accounts for two LDS buffers per
    // operand -- confirm against the pipeline's buffering scheme.
    return 2 * (smem_size_a + smem_size_b);
}
|
||||
|
||||
/// Returns the newline-terminated name of this pipeline.
CK_TILE_HOST static std::string Print()
{
    return std::string("ABQuantGemmPipelineAgBgCrEightWaves\n");
}
|
||||
|
||||
@@ -45,6 +45,55 @@ class TestMxGemmUtil : public ::testing::Test
|
||||
using ScaleM = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
using ScaleN = ck_tile::MXScalePointer<ScaleType, 1, 32>;
|
||||
|
||||
// Pack [MN, K/32] e8m0_t scales into [MN/MNPack, K/32/KPack] int32_t
|
||||
// Each int32_t contains MNPack * KPack e8m0_t values with byte layout matching
|
||||
// the GPU tile distribution: values are XdlMNThread apart in M and XdlKThread apart in K.
|
||||
// byte[ik * MNPack + imn] = e8m0 at strided (mn, k) position
|
||||
// kLast=true for A scales (layout [M, K/32]), kLast=false for B scales (layout [K/32, N])
|
||||
template <ck_tile::index_t MNPack = 2,
|
||||
ck_tile::index_t KPack = 2,
|
||||
ck_tile::index_t XdlMNThread = 16,
|
||||
ck_tile::index_t XdlKThread = 4>
|
||||
static auto packScalesMNxK(const ck_tile::HostTensor<ck_tile::e8m0_t>& src, bool kLast)
|
||||
{
|
||||
auto src_lengths = src.get_lengths();
|
||||
const ck_tile::index_t MN = kLast ? src_lengths[0] : src_lengths[1];
|
||||
const ck_tile::index_t K_scale = kLast ? src_lengths[1] : src_lengths[0];
|
||||
const ck_tile::index_t MN_packed = MN / MNPack;
|
||||
const ck_tile::index_t K_packed = K_scale / KPack;
|
||||
const ck_tile::index_t total_packed = MN_packed * K_packed;
|
||||
|
||||
std::vector<int32_t> packed(total_packed);
|
||||
|
||||
for(ck_tile::index_t packed_mn = 0; packed_mn < MN_packed; packed_mn++)
|
||||
{
|
||||
for(ck_tile::index_t packed_k = 0; packed_k < K_packed; packed_k++)
|
||||
{
|
||||
int32_t val = 0;
|
||||
ck_tile::index_t mn_lane = packed_mn % XdlMNThread;
|
||||
ck_tile::index_t mn_group = packed_mn / XdlMNThread;
|
||||
ck_tile::index_t k_lane = packed_k % XdlKThread;
|
||||
ck_tile::index_t k_group = packed_k / XdlKThread;
|
||||
for(ck_tile::index_t ik = 0; ik < KPack; ik++)
|
||||
{
|
||||
for(ck_tile::index_t imn = 0; imn < MNPack; imn++)
|
||||
{
|
||||
ck_tile::index_t byteIdx = ik * MNPack + imn;
|
||||
ck_tile::index_t orig_mn =
|
||||
mn_group * XdlMNThread * MNPack + imn * XdlMNThread + mn_lane;
|
||||
ck_tile::index_t orig_k =
|
||||
k_group * XdlKThread * KPack + ik * XdlKThread + k_lane;
|
||||
|
||||
ck_tile::e8m0_t v = kLast ? src(orig_mn, orig_k) : src(orig_k, orig_mn);
|
||||
val |= (static_cast<int32_t>(v.get()) << (byteIdx * 8));
|
||||
}
|
||||
}
|
||||
packed[packed_mn * K_packed + packed_k] = val;
|
||||
}
|
||||
}
|
||||
return packed;
|
||||
}
|
||||
|
||||
void Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, int seed = 1234)
|
||||
{
|
||||
const ck_tile::index_t scale_k_size = K / 32;
|
||||
@@ -75,17 +124,43 @@ class TestMxGemmUtil : public ::testing::Test
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_a_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.001f, 10.f, seed++}(scale_b_host);
|
||||
|
||||
// Compute effective XdlPack sizes based on GemmConfig tile dimensions
|
||||
constexpr ck_tile::index_t MPerXdl = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t NPerXdl = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t KPerXdl = GemmConfig::K_Warp_Tile;
|
||||
constexpr ck_tile::index_t MIterPerWarp =
|
||||
GemmConfig::M_Tile / (GemmConfig::M_Warp * MPerXdl);
|
||||
constexpr ck_tile::index_t NIterPerWarp =
|
||||
GemmConfig::N_Tile / (GemmConfig::N_Warp * NPerXdl);
|
||||
constexpr ck_tile::index_t KIterPerWarp = GemmConfig::K_Tile / KPerXdl;
|
||||
|
||||
constexpr ck_tile::index_t MXdlPackEff =
|
||||
(MIterPerWarp >= 2 && MIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t NXdlPackEff =
|
||||
(NIterPerWarp >= 2 && NIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
constexpr ck_tile::index_t KXdlPackEff =
|
||||
(KIterPerWarp >= 2 && KIterPerWarp % 2 == 0) ? 2 : 1;
|
||||
|
||||
constexpr ck_tile::index_t XdlMNThread = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
// Pack scales into int32_t for GPU consumption
|
||||
auto scale_a_packed =
|
||||
packScalesMNxK<MXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_a_host, true);
|
||||
auto scale_b_packed =
|
||||
packScalesMNxK<NXdlPackEff, KXdlPackEff, XdlMNThread, XdlKThread>(scale_b_host, false);
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_dev_buf(b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_packed.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_packed.size() * sizeof(int32_t));
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
b_dev_buf.ToDevice(b_host.data());
|
||||
c_dev_buf.SetZero();
|
||||
scale_a_dev_buf.ToDevice(scale_a_host.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_host.data());
|
||||
scale_a_dev_buf.ToDevice(scale_a_packed.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_packed.data());
|
||||
|
||||
ScaleM scale_m(reinterpret_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()));
|
||||
ScaleN scale_n(reinterpret_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()));
|
||||
|
||||
Reference in New Issue
Block a user