Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-15 18:42:06 +00:00
Add MoE & FP8 Blockscale WP Kernels for GFX950 (#2297)
* [fix] align v3 gufusion pipeline
* fix device kernel selection.
* Add .co direct asm support by CK_USE_ASM_MOE_STAGE2_BLOCKSCALE
* experimental optimization for scale load in blkscale gemm
* Add asm for no-loop v3_128x128x128
* fix bugs
* tune fp8 example
* Update v1_128x128x128 to 2x2 instead of 4x1
* wip
* add warmup to asm launch
* wip2
* 16x16 function merged to moe
* temp save, a performant version.
* wip3
* Update .co binary to 16x16
* 16x16x128 correct; 64x64x128 failed
* update
* use mem_op::set when topk=1
* add mx fp8 b_preshuffle support, function not yet tested.
* Split the fp4 target. Fix the known bugs. 128x128x128 sanity checked; remove prints
* some fixes
* fix update
* remove some unnecessary hacks; enable 256x256x256 tile size
* update for function debug
* Add pipeline v3. Have some runtime issue and register spill
* Fix pipe v3 correctness issue
* remove unnecessary hacks
* clang format
* fix a bug
* fix the bug, functional test passed
* temp save; buggy when passing 4 e8m0 values to scaled mfma
* added fp4_bpreshuffle example, build failures
* fixed some bugs
* implement shuffled scale mxfp4 gemm; blocker: opsel has no effect
* hotfix
* fix bugs, build passed
* (M, N, K)=(128, 128, 128) function failed.
* temp save for gemm1. Function not ready
* fix compile error. Gemm2 pass. Gemm1 WIP
* fix bug for an lds read
* update moe
* Compile pass. Gemm1 function WIP
* update moe
* fix fp8; fix even/odd
* tempsave
* update moe
* Revert "update"
This reverts commit c7d79dcb672616d9bc0fd9958f714fc80e7c84fd.
* Revert "use mem_op::set when topk=1"
This reverts commit 8c7772860735001a51421e7b6d0a28f6676d6c40.
* Add v3 128x128x128_4x4_16x16.co for gfx950
* temp cmake flag suppression for aiter test
* add code for mxfp4 gemm, blockscale not supported yet
* gemm1 up-only pass. GU WIP
* function pass with inline asm hack
* revert unexpected file change
* updated and build passed
* update CE elementOP
* added code for debug
* Gemm1 GUFusion function pass. Perf WIP
* Fix fp8/bf8; remove duplicated code
* disable the scheduler in v3; bring it back when the compiler feature is ready.
* update moe v1 pipeline
* Add gemm1 v1 32x128x128
* remove schedule barrier
* updated
* Fix fp8/bf8 B-row
* mfma using asm; device result correct, host result needs checking
* gemm1 v3 64x128x128 debug
* fix cpu ref
* a/b thread_desc stride fix
* Use random scale for init1
* 16x16x128 input size blockscale function passed
* fix blockscale gemm bug
* tempsave. Almost all instances passed.
* v1 fix for mi350.
* temp save
* debug save
* update debug
* fix the bug, 128x128x256 tile function passed
* v3
* rename moe block selector and pipeline
* Add gemm1 v1
* Add gemm1 v1 to selector
* added mx moe block v3 support, function passed
* compile error fix
* Improve the pipeline
* Pack e8m0 as int32_t
* v1 compile pass. Function not ready
* debug synchronize issue over different GPU/ROCm
* minor fix
* Add profiler filter
* Add f4 ckProfiler
* Fix example compile error
* Add f4 profiler examples
* tempsave
* v1 function pass.
* v3 function pass
* align file and function name
* mx_moe_fp4 ready for aiter with clang-format.
* modify the way we represent fp4
* generalize the pipeline scheduling.
* init moe mx f4 scale shuffle
* CMakeLists: disable compiler-bound flags
* mx_fp4 default parameter change
* MoE blockscale gemm1 & gemm2 asm support for aiter. Suppress cmake flag until new compiler.
* update code
* tempsave; modify the way we represent fp4
* generalize the pipeline scheduling.
* Add gemm1 gfx942 .co support
* updated code, build passed.
* Update gemm2 asm with latest compiler flag
* Fix mx f4 ckProfiler
* Fix blockwise gemm mx v1
* lds conflict free + buffer load lds
* Add gemm2 v3 64x128x128
* fix a/b scale loading bugs; a/b scales now load correctly
* Add gemm2 v3 64x128x128
* commit with debug info
* fix fp4 profiler
* Add mx fp4 pipeline v1 instances
* Fix v2 topk_weight calculation. Add silu asm.
* v2 topk_weight WIP
* init mx fp4 B no preshuffle version
* tempsave. compile pass, function wrong
* enable fp4 moe with no weight preshuffle, function pass
* update the TFlops calculation in the example
* Add gemm2 64x128x128 asm. Fix BF16 ref.
* fix 2 typos in fp4_preshuffle
* Better kernel selection in device classes
* correct preShuffleBuffer
we should use packed K to do the shuffle.
* lds conflict free + buffer load lds
* optimize offset math in dma
* Fix fp4 ckProfiler
* Fix MX MFMA tests
* fix f4 pipeline issues
* gemm1 func pass
* update mx moe gemm1_bns tile size to 64x128x256
* update mx moe gemm1 gemm2 TF and BW calculation
* fix typo
* temp save
* Fix example_gemm_mx build
* rename the block pipeline
* correct a typo in tail
* Add rotating to mx examples
* fix the correctness issue
* Fix v1; use M padding
* Add NT flag to B/BScale buffer
* Merge gemm_mx_common.hpp
* temp save, 4.4~4.5
* Fix 'Merge gemm_mx_common.hpp'
* refactor the pipeline
* Pad the M for scale buffer unconditionally
* update MX moe GEMM1 hotloopscheduling
* change the gemm1 tile from 64x128x128 to 128x64x128
* Unconditional Ascale padding
* Pad shuffled a scale only
* pad ascale
* add vmcnt guard for async copy
* Profiler add f4 wp
* Merge preshuffle device
* Add more fp4 wp instances
* Fix do_weight in gemm1. Fix cshuffle_datatype. Clang-format
* Clang-format after 2 merges
* Remove rocm6.3 workaround flags and macro
* Fix fp8 config
* Fix bf8 config
* flag and barrier fix for compiler branch MainOpSelV3
* Add fp8 profiler instances
* Remove debug infos; Enable flags for blockscale f8
* No asm ver. for merging moe blockscale fp8 into mainline
* update the flag name for f8blockscale
* recover example
* fix performance bug of bpreshuffle f8 gemm
* clang format, remove single rate mfma restriction for f8
* remove single rate mfma restriction for f8 blockscale gemm
* Fix moe blockscale gemm1 barrier 0x800 for new compiler
* add pipeline v1 for MOE Gemm2
* Use v1 pipeline for example_moe_gemm2_xdl_mx_fp4_bns
* Fix OOB; add MB96 instances
* remove unnecessary files
* fix the cmake issue
* Enable splitk for mxfp4; clang format;
* Generate random tensor values with multiple threads
* Use packed_size_v for A/BPackedSize
* Fix warning
* Fix target_compile_options for disabled target on gfx942
* fix moe pki4 on gfx950
* doc the kGroup definition
* Fix ThreadwiseTensorSliceTransfer_v4::Run (Fuse scale)
* Refactor thread_copy_lds_direct_load; fix gfx942 direct lds load example; fix f16_pki4 example
* Fix unknown compiler flag
* fix two failing examples.
* fix some failure tile size in gfx950 universal gemm. fix test_gemm_fp16
* workaround fix for test_gemm_f32; we have very limited support for lds direct load if the input matrix is not K major
* fix test_gemm_splitk;
* Fix compile for mx_mfma_op
* add mfma selection logic for multipled_v3
* Clean up
* Fix device gemm mx link error
* improve the global atomic pattern
* Revert unnecessary copyright updates
* restore minimum_occupancy logic
* Avoid data race in moe gemm2 ref
* Build fp8 gemm_multiply_multiply and moe only on gfx94/95
* update the instance in device_mx_gemm
* Resolve comments
* Copyright 2025
* Remove unused code
* fix library linking issue
---------
Co-authored-by: OscarXu <huaiguxu@amd.com>
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: valarLip <340077269@qq.com>
Co-authored-by: feifei14119 <feiw@amd.com>
Co-authored-by: Lin, Qun <qlin@amd.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
[ROCm/composable_kernel commit: 37554c31e8]
@@ -8,6 +8,7 @@
 #include <iostream>
 #include <fstream>
 #include <numeric>
 #include <random>
+#include <thread>
 #include <utility>
 #include <vector>

@@ -18,6 +19,7 @@

 #include "ck/library/utility/algorithm.hpp"
 #include "ck/library/utility/ranges.hpp"
+#include "ck/library/utility/thread.hpp"

 template <typename Range>
 std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)

@@ -512,6 +514,72 @@ struct Tensor
        }
    }

    // Generate random values with multiple threads. Guaranteed to give the same sequence with any
    // number of threads provided.
    template <typename Distribution = std::uniform_real_distribution<float>,
              typename Mapping      = ck::identity,
              typename Generator    = std::minstd_rand>
    void GenerateTensorDistr(Distribution dis = {0.f, 1.f},
                             Mapping fn       = {},
                             const Generator g = Generator(0), // default seed 0
                             std::size_t num_thread = -1)
    {
        using ck::math::integer_divide_ceil;
        using ck::math::min;
        if(num_thread == -1ULL)
            num_thread = min(ck::get_available_cpu_cores(), 80U); // max 80 threads
        // At least 2MB per thread
        num_thread = min(num_thread, integer_divide_ceil(this->GetElementSpaceSize(), 0x200000));
        constexpr std::size_t BLOCK_BYTES = 64;
        constexpr std::size_t BLOCK_SIZE  = BLOCK_BYTES / sizeof(T);

        const std::size_t num_blocks = integer_divide_ceil(this->GetElementSpaceSize(), BLOCK_SIZE);
        const std::size_t blocks_per_thread = integer_divide_ceil(num_blocks, num_thread);

        std::vector<std::thread> threads;
        threads.reserve(num_thread - 1);
        const auto dst                = const_cast<T*>(this->mData.data());
        const auto element_space_size = this->GetElementSpaceSize();
        for(int it = num_thread - 1; it >= 0; --it)
        {
            std::size_t ib_begin = it * blocks_per_thread;
            std::size_t ib_end   = min(ib_begin + blocks_per_thread, num_blocks);

            auto job = [=]() {
                auto g_   = g;   // copy
                auto dis_ = dis; // copy
                g_.discard(ib_begin * BLOCK_SIZE * ck::packed_size_v<T>);
                auto t_fn = [&]() {
                    if constexpr(ck::packed_size_v<T> == 1)
                        return ck::type_convert<T>(fn(dis_(g_)));
                    else if constexpr(ck::is_same_v<T, ck::f4x2_pk_t>)
                        return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(
                            ck::float2_t{ck::type_convert<float>(fn(dis_(g_))),
                                         ck::type_convert<float>(fn(dis_(g_)))})};
                    else
                        static_assert(false, "Unsupported packed size for T");
                };

                std::size_t ib = ib_begin;
                for(; ib < ib_end - 1; ++ib)
                    ck::static_for<0, BLOCK_SIZE, 1>{}([&](auto iw_) {
                        constexpr size_t iw          = iw_.value;
                        dst[ib * BLOCK_SIZE + iw]    = t_fn();
                    });
                for(std::size_t iw = 0; iw < BLOCK_SIZE; ++iw)
                    if(ib * BLOCK_SIZE + iw < element_space_size)
                        dst[ib * BLOCK_SIZE + iw] = t_fn();
            };

            if(it > 0)
                threads.emplace_back(std::move(job));
            else
                job(); // last job run in the main thread
        }
        for(auto& t : threads)
            t.join();
    }

    template <typename... Is>
    std::size_t GetOffsetFromMultiIndex(Is... is) const
    {
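For orientation, a minimal usage sketch of the new generator (the tensor shape, range, and thread count here are hypothetical, not taken from this PR):

    // Fill a host tensor with uniform values; the same seed gives the same
    // contents regardless of how many threads actually run, because each
    // worker copies the engine and discard()s ahead to its own block boundary.
    Tensor<float> a({1024, 1024});
    a.GenerateTensorDistr(std::uniform_real_distribution<float>{-1.f, 1.f},
                          ck::identity{},      // no post-mapping of samples
                          std::minstd_rand(0), // seed 0, matching the default
                          8);                  // cap worker threads at 8

Determinism follows from the discard() call: the sample consumed for element i is fixed by i alone, so the thread partition cannot change the result.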
@@ -163,6 +163,18 @@ struct GeneratorTensor_1<ck::pk_i4_t>
    }
};

template <>
struct GeneratorTensor_1<ck::e8m0_bexp_t>
{
    float value = 1;

    template <typename... Is>
    ck::e8m0_bexp_t operator()(Is...)
    {
        return ck::type_convert<ck::e8m0_bexp_t>(value);
    }
};

template <typename T>
struct GeneratorTensor_2
{
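A quick sketch of the new specialization in use (hypothetical call, not part of the diff):

    // Fill an e8m0 scale tensor with the constant 1.0f. e8m0 stores only a
    // biased exponent, so type_convert maps 1.0f to the bias encoding (127).
    GeneratorTensor_1<ck::e8m0_bexp_t> gen{1.0f};
    ck::e8m0_bexp_t one = gen(0, 0); // index arguments are ignored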
include/ck/library/utility/thread.hpp (new file, 25 lines)
@@ -0,0 +1,25 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#ifdef __linux__
#include <sched.h>
#endif
#include <thread>
namespace ck {
inline unsigned int get_available_cpu_cores()
{
#if defined(__linux__)
    cpu_set_t cpu_set;
    if(sched_getaffinity(0, sizeof(cpu_set_t), &cpu_set) == 0)
    {
        unsigned int cpu_count = CPU_COUNT(&cpu_set);
        if(cpu_count > 0)
            return cpu_count;
    }
#endif
    // Fallback if sched_getaffinity unavailable or fails
    return std::thread::hardware_concurrency();
}
} // namespace ck
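The affinity query matters on shared machines; a sketch of the intended behavior (the taskset scenario is my illustration, not from the diff):

    // Under `taskset -c 0-3 ./a.out`, get_available_cpu_cores() should return
    // 4, while std::thread::hardware_concurrency() may still report all cores.
    #include "ck/library/utility/thread.hpp"
    #include <cstdio>

    int main() { std::printf("%u\n", ck::get_available_cpu_cores()); }

This is what keeps GenerateTensorDistr from oversubscribing cpuset-restricted CI runners or containers.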
@@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
     using Base::B_K1;
     using Base::I0;
     using Base::I1;
+    using Base::KGroup;
     using Base::KRepeat;
     using Base::xdlops_gemm;
     using typename Base::HotLoopInstList;

@@ -153,9 +154,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
         constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
         constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
         constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
-        constexpr index_t K2 = KPack;
+        constexpr index_t K2 = KPack / KGroup;
         constexpr index_t K1 = 64 / NPerXDL;
-        constexpr index_t K0 = KRepeat;
+        constexpr index_t K0 = KRepeat * KGroup;

         return transform_tensor_descriptor(
             TileDesc_M0_M1_M2_K{},

@@ -290,12 +291,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
         block_sync_lds();
         static_for<0, MRepeat, 1>{}([&](auto m0) {
             static_for<0, KRepeat, 1>{}([&](auto k0) {
-                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
-                                   make_tuple(m0, I0, I0, k0, I0, I0),
-                                   a_block_buf,
-                                   a_thread_desc_,
-                                   make_tuple(m0, I0, I0, k0, I0, I0),
-                                   a_thread_buf);
+                static_for<0, KGroup, 1>{}([&](auto kg0) {
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
+                                       make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
+                                       a_block_buf,
+                                       a_thread_desc_,
+                                       make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
+                                       a_thread_buf);
+                });
             });
         });
         // B VGPR->VGPR dequant

@@ -388,12 +391,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<

         static_for<0, MRepeat, 1>{}([&](auto m0) {
             static_for<0, KRepeat, 1>{}([&](auto k0) {
-                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
-                                   make_tuple(m0, I0, I0, k0, I0, I0),
-                                   a_block_buf,
-                                   a_thread_desc_,
-                                   make_tuple(m0, I0, I0, k0, I0, I0),
-                                   a_thread_buf);
+                static_for<0, KGroup, 1>{}([&](auto kg0) {
+                    a_thread_copy_.Run(
+                        a_block_desc_m0_m1_m2_k0_k1_k2,
+                        make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
+                        a_block_buf,
+                        a_thread_desc_,
+                        make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
+                        a_thread_buf);
+                });
             });
         });
         // B VGPR->VGPR dequant

@@ -477,12 +483,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<

         static_for<0, MRepeat, 1>{}([&](auto m0) {
             static_for<0, KRepeat, 1>{}([&](auto k0) {
-                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
-                                   make_tuple(m0, I0, I0, k0, I0, I0),
-                                   a_block_buf,
-                                   a_thread_desc_,
-                                   make_tuple(m0, I0, I0, k0, I0, I0),
-                                   a_thread_buf);
+                static_for<0, KGroup, 1>{}([&](auto kg0) {
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
+                                       make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
+                                       a_block_buf,
+                                       a_thread_desc_,
+                                       make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
+                                       a_thread_buf);
+                });
             });
         });
         // B VGPR->VGPR dequant

@@ -588,7 +596,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
         ComputeDataType,
         decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
         decltype(a_thread_desc_),
-        Sequence<1, 1, 1, 1, 1, KPack>,
+        Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
         Sequence<0, 1, 2, 3, 4, 5>,
         5,
         A_K1,
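The two descriptor edits above are shape-preserving, assuming KPack is divisible by KGroup; a compile-time restatement of the invariant (my sketch, not code from the PR):

    // K0 grows by KGroup exactly as K2 shrinks by it, so the unmerged
    // K0 x K1 x K2 tile covers the same KRepeat * (64/NPerXDL) * KPack lanes.
    static_assert(KRepeat * KGroup * (64 / NPerXDL) * (KPack / KGroup) ==
                  KRepeat * (64 / NPerXDL) * KPack);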
@@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
     using Base::B_K1;
     using Base::I0;
     using Base::I1;
+    using Base::KGroup;
     using Base::KRepeat;
     using Base::xdlops_gemm;
     using typename Base::HotLoopInstList;

@@ -154,9 +155,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
         constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
         constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
         constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
-        constexpr index_t K2 = KPack;
+        constexpr index_t K2 = KPack / KGroup;
         constexpr index_t K1 = 64 / NPerXDL;
-        constexpr index_t K0 = KRepeat;
+        constexpr index_t K0 = KRepeat * KGroup;

         return transform_tensor_descriptor(
             TileDesc_M0_M1_M2_K{},

@@ -298,12 +299,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
         block_sync_lds();
         static_for<0, MRepeat, 1>{}([&](auto m0) {
             static_for<0, KRepeat, 1>{}([&](auto k0) {
-                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
-                                   make_tuple(m0, I0, I0, k0, I0, I0),
-                                   a_block_buf,
-                                   a_thread_desc_,
-                                   make_tuple(m0, I0, I0, k0, I0, I0),
-                                   a_thread_buf);
+                static_for<0, KGroup, 1>{}([&](auto kg0) {
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
+                                       make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
+                                       a_block_buf,
+                                       a_thread_desc_,
+                                       make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
+                                       a_thread_buf);
+                });
             });
         });

@@ -382,12 +385,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch

         static_for<0, MRepeat, 1>{}([&](auto m0) {
             static_for<0, KRepeat, 1>{}([&](auto k0) {
-                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
-                                   make_tuple(m0, I0, I0, k0, I0, I0),
-                                   a_block_buf,
-                                   a_thread_desc_,
-                                   make_tuple(m0, I0, I0, k0, I0, I0),
-                                   a_thread_buf);
+                static_for<0, KGroup, 1>{}([&](auto kg0) {
+                    a_thread_copy_.Run(
+                        a_block_desc_m0_m1_m2_k0_k1_k2,
+                        make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
+                        a_block_buf,
+                        a_thread_desc_,
+                        make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
+                        a_thread_buf);
+                });
             });
         });

@@ -458,12 +464,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch

         static_for<0, MRepeat, 1>{}([&](auto m0) {
             static_for<0, KRepeat, 1>{}([&](auto k0) {
-                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
-                                   make_tuple(m0, I0, I0, k0, I0, I0),
-                                   a_block_buf,
-                                   a_thread_desc_,
-                                   make_tuple(m0, I0, I0, k0, I0, I0),
-                                   a_thread_buf);
+                static_for<0, KGroup, 1>{}([&](auto kg0) {
+                    a_thread_copy_.Run(
+                        a_block_desc_m0_m1_m2_k0_k1_k2,
+                        make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
+                        a_block_buf,
+                        a_thread_desc_,
+                        make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
+                        a_thread_buf);
+                });
             });
         });

@@ -556,7 +565,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
         ComputeDataType,
         decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
         decltype(a_thread_desc_),
-        Sequence<1, 1, 1, 1, 1, KPack>,
+        Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
         Sequence<0, 1, 2, 3, 4, 5>,
         5,
         A_K1,
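The same KGroup pattern recurs in every hunk of both files: one KPack-wide A-fragment read per k0 becomes KGroup reads of A_K1 = KPack / KGroup elements each. A standalone sketch of just the indexing, with plain arrays and made-up sizes instead of the CK descriptors:

    #include <algorithm>

    constexpr int KRepeat = 4, KPack = 32, KGroup = 2, A_K1 = KPack / KGroup;

    // src holds A_K1-sized chunks addressed by k0 * KGroup + kg0; dst packs
    // them back into one contiguous KPack-wide fragment per k0, the chunk
    // for kg0 landing at offset kg0 * A_K1, mirroring the make_tuple
    // source/destination coordinates in the hunks above.
    void copy_a_fragments(const float* src, float* dst)
    {
        for(int k0 = 0; k0 < KRepeat; ++k0)
            for(int kg0 = 0; kg0 < KGroup; ++kg0)
                std::copy_n(src + (k0 * KGroup + kg0) * A_K1,
                            A_K1,
                            dst + k0 * KPack + kg0 * A_K1);
    }

Note the bdequant_v1 hunks hard-code the factor 2 (k0 * 2 + kg0) where gufusion_v1 uses k0 * KGroup + kg0; the two agree only while KGroup == 2.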
@@ -0,0 +1,952 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"

namespace ck {

// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1

template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3
{
};

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineScheduler::Intrawave,
                                                            BlockSize,
                                                            ADataType,
                                                            BDataType,
                                                            ComputeDataType,
                                                            AccDataType,
                                                            ATileDesc,
                                                            BTileDesc,
                                                            AMmaTileDesc,
                                                            BMmaTileDesc,
                                                            ABlockTransferSrcScalarPerVector,
                                                            BBlockTransferSrcScalarPerVector,
                                                            MPerBlock,
                                                            NPerBlock,
                                                            KPerBlock,
                                                            MPerXDL,
                                                            NPerXDL,
                                                            MRepeat,
                                                            NRepeat,
                                                            KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>

{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
    using Base::A_K1;
    using Base::B_K1;
    using Base::I0;
    using Base::I1;
    using Base::I2;
    using Base::KGroup;
    using Base::KRepeat;
    using Base::xdlops_gemm;
    using typename Base::HotLoopInstList;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    using Base::MWaves;

    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;
    static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;

    template <typename TileDesc_M0_M1_M2_K>
    __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
    {
        constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
        constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
        constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
        constexpr index_t K2 = KPack / KGroup;
        constexpr index_t K1 = 64 / NPerXDL;
        constexpr index_t K0 = KRepeat * KGroup;

        return transform_tensor_descriptor(
            TileDesc_M0_M1_M2_K{},
            make_tuple(
                make_pass_through_transform(Number<M0>{}),
                make_pass_through_transform(Number<M1>{}),
                make_pass_through_transform(Number<M2>{}),
                make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
    }

    static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
        MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
    }

    __device__ static constexpr auto HotLoopScheduler()
    {
        // A/B split schedule
        // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;

        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;

        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2;

        static_assert(num_buffer_load_inst_a == num_ds_write_inst_a);

        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * 2;
        constexpr auto mfma_cycle    = HotLoopInstList::C_MFMA_Inst_Cycle;

        constexpr auto ds_read_a_issue_cycle =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
        constexpr auto ds_read_a_mfma_rate =
            math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle);

        // constexpr auto num_dsread_a_mfma =
        //     (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;

        constexpr auto num_total_stages = MRepeat;

        // Group num_mfma_perstage num_ds_read_a_perstage
        // since we want to reuse a local register buffer
        constexpr auto num_mfma_perstage      = num_mfma_inst / num_total_stages;
        constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;

        constexpr auto num_ds_read_a_mfma_perstage =
            math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);

        constexpr auto num_ds_read_a_prefetch_stages = 2;

        constexpr auto buffer_load_perstage_more = math::integer_divide_ceil(
            (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
        constexpr auto buffer_load_perstage_less = math::integer_divide_floor(
            (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));

        constexpr auto buffer_load_stages_more =
            (num_buffer_load_inst_a + num_buffer_load_inst_b) -
            math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b),
                                       (num_total_stages - 2)) *
                ((num_total_stages - 2));

        constexpr auto buffer_load_b_stages =
            buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b
                ? num_buffer_load_inst_b / buffer_load_perstage_more
                : (buffer_load_stages_more +
                   (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) /
                       buffer_load_perstage_less);

        constexpr auto buffer_load_a_stages =
            num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages;

        constexpr auto buffer_load_issue_point_b = 0;
        constexpr auto buffer_load_issue_point_interval_more =
            num_mfma_perstage / buffer_load_perstage_more;
        constexpr auto buffer_load_issue_point_interval_less =
            num_mfma_perstage / buffer_load_perstage_less;
        constexpr auto ds_write_issue_point       = 0;
        constexpr auto buffer_load_issue_point_a  = num_mfma_perstage >= 3 ? 1 : 0;

        // B global read
        static_for<0, buffer_load_b_stages, 1>{}([&](auto i) {
            static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA

                if constexpr(((i < buffer_load_stages_more) &&
                              (imfma % buffer_load_issue_point_interval_more ==
                               buffer_load_issue_point_b)) ||
                             ((i >= buffer_load_stages_more) &&
                              (imfma % buffer_load_issue_point_interval_less ==
                               buffer_load_issue_point_b)))
                {
                    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                }

                if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
                {
                    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
                }
            });
        });

        // A global read + A local write
        static_for<0, buffer_load_a_stages, 1>{}([&](auto i) {
            static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
                              (imfma % buffer_load_issue_point_interval_more ==
                               ds_write_issue_point)) ||
                             (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
                              (imfma % buffer_load_issue_point_interval_less ==
                               ds_write_issue_point)))
                {
                    __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                }
                if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
                              (imfma % buffer_load_issue_point_interval_more ==
                               buffer_load_issue_point_a)) ||
                             (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
                              (imfma % buffer_load_issue_point_interval_less ==
                               buffer_load_issue_point_a)))
                {
                    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                }
                if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
                {
                    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
                }
            });
        });

        // lds synchronization, prefetch next loop local A
        static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
                {
                    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
                }
            });
        });
    }

    template <typename Stage>
    __device__ static constexpr auto EpilogueScheduler_1(Stage stage)
    {
        constexpr auto num_ds_read_inst_a  = HotLoopInstList::A_LDS_Read_Inst_Num;
        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_buffer_load_inst_b =
            MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num * 2;

        constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2;

        constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
        constexpr auto staged_num_mfma           = num_mfma / MRepeat;

        constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;

        if constexpr(stage.value == 0)
        {
            constexpr auto staged_num_buffer_load_b_per_ds_read_a =
                num_buffer_load_inst_b / staged_num_ds_read_inst_a;
            constexpr auto staged_num_mfma_per_buffer_load_b =
                staged_num_mfma / num_buffer_load_inst_b;
            // B global
            static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
                ignore = i_inst;

                static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) {
                    ignore = ibuf_inst;
                    __builtin_amdgcn_sched_group_barrier(
                        0x008, staged_num_mfma_per_buffer_load_b, 0);  // MFMA
                    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
                });

                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                __builtin_amdgcn_sched_group_barrier(
                    0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);    // VMEM read
            });

            __builtin_amdgcn_sched_barrier(0);
        }
        else if constexpr(stage.value == 1)
        {
            constexpr auto staged_num_mfma_per_ds_write_a =
                math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);

            constexpr auto stage_more_mfma =
                staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;

            // A local write
            static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
                if constexpr(i_inst.value < stage_more_mfma)
                {
                    if(i_inst.value < staged_num_ds_read_inst_a)
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
                        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                    }
                    else
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_ds_write_a, 0);     // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
                    }
                }
                else
                {
                    if(i_inst.value < staged_num_ds_read_inst_a)
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
                        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                    }
                    else
                    {
                        __builtin_amdgcn_sched_group_barrier(
                            0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
                        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
                    }
                }
            });
            __builtin_amdgcn_sched_barrier(0);
        }
        else
        {
            // A local Read
            static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
                ignore = i_inst;
                __builtin_amdgcn_sched_group_barrier(
                    0x008, staged_num_mfma_per_ds_read_a, 0);      // MFMA
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
            });

            __builtin_amdgcn_sched_barrier(0);
        }
    }

    __device__ static constexpr auto EpilogueScheduler_2()
    {
        constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;

        constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2;

        constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
        constexpr auto staged_num_mfma           = num_mfma / MRepeat;

        constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;

        // A local Read
        static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
            ignore = i_inst;
            __builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);                             // DS read
        });

        __builtin_amdgcn_sched_barrier(0);
    }

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        BBlockTransfer& b_blockwise_copy,
                        BBlockTransfer& b_blockwise_copy_up,
                        const BGridBuffer& b_grid_buf,
                        const BGridBuffer& b_grid_buf_up,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        CThreadBuffer& c_thread_buf_up,
                        index_t num_loop) const
    {
        ignore = b_block_buf;
        __builtin_amdgcn_sched_barrier(0);
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
        StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);

        // Global prefetch A1 B1
        b_blockwise_copy.Run(b_grid_desc,
                             b_grid_buf,
                             b_block_desc_n0_n1_k0_k1,
                             b_block_origin_idx,
                             b_thread_bufs(I0));

        b_blockwise_copy_up.Run(b_grid_desc,
                                b_grid_buf_up,
                                b_block_desc_n0_n1_k0_k1,
                                b_block_origin_idx,
                                b_thread_bufs_up(I0));
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        __builtin_amdgcn_sched_barrier(0);

        // // Local prefill A1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));

        // // Global prefetch A2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

        // Local prefetch A1
        block_sync_lds();
        static_for<0, 2, 1>{}([&](auto m0) {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, KGroup, 1>{}([&](auto kg0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
                                       make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
                                       a_block_buf.At(I0),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
                                       a_thread_buf);
                });
            });
        });

        // Initialize C
        c_thread_buf.Clear();
        c_thread_buf_up.Clear();

        __builtin_amdgcn_sched_barrier(0);

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
                    b_blockwise_copy.Run(b_grid_desc,
                                         b_grid_buf,
                                         b_block_desc_n0_n1_k0_k1,
                                         b_block_origin_idx,
                                         b_thread_bufs(local_read_buf));
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                    b_blockwise_copy_up.Run(b_grid_desc,
                                            b_grid_buf_up,
                                            b_block_desc_n0_n1_k0_k1,
                                            b_block_origin_idx,
                                            b_thread_bufs_up(local_read_buf));
                    b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, KRepeat, 1>{}([&](auto k0) {
                            static_for<0, NRepeat, 1>{}([&](auto n0) {
                                vector_type<ComputeDataType, KPack> a_thread_vec;
                                vector_type<ComputeDataType, KPack> b_thread_vec;
                                vector_type<ComputeDataType, KPack> b_thread_vec_up;

                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                            make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %
                                                           2,
                                                       I0,
                                                       I0,
                                                       k0,
                                                       I0,
                                                       ik))>{}];
                                    b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        b_thread_bufs[mfma_reg_buf]
                                                     [Number<b_thread_desc_.CalculateOffset(
                                                         make_tuple(n0, I0, k0, ik))>{}];

                                    b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
                                        b_thread_bufs_up[mfma_reg_buf]
                                                        [Number<b_thread_desc_.CalculateOffset(
                                                            make_tuple(n0, I0, k0, ik))>{}];
                                });

                                using mfma_input_type =
                                    typename vector_type<ComputeDataType,
                                                         xdlops_gemm.K1PerXdlops>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                xdlops_gemm.Run(
                                    a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

                                xdlops_gemm.Run(
                                    a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec_up.template AsType<mfma_input_type>(),
                                    c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
                            });
                        });

                        if constexpr(m0.value == MRepeat - 2)
                        {
                            block_sync_lds();

                            static_for<0, KRepeat, 1>{}([&](auto k0) {
                                static_for<0, KGroup, 1>{}([&](auto kg0) {
                                    a_thread_copy_.Run(
                                        a_block_desc_m0_m1_m2_k0_k1_k2,
                                        make_tuple(Number<(m0 + 2) % MRepeat>{},
                                                   I0,
                                                   I0,
                                                   Number<k0 * KGroup + kg0>{},
                                                   I0,
                                                   I0),
                                        a_block_buf.At(local_read_buf),
                                        a_thread_desc_,
                                        make_tuple(
                                            Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
                                                   2>{},
                                            I0,
                                            I0,
                                            k0,
                                            I0,
                                            Number<kg0 * A_K1>{}),
                                        a_thread_buf);
                                });
                            });
                        }
                        else if constexpr(m0.value == (MRepeat - 1))
                        {
                            static_for<0, KRepeat, 1>{}([&](auto k0) {
                                static_for<0, KGroup, 1>{}([&](auto kg0) {
                                    a_thread_copy_.Run(
                                        a_block_desc_m0_m1_m2_k0_k1_k2,
                                        make_tuple(Number<(m0 + 2) % MRepeat>{},
                                                   I0,
                                                   I0,
                                                   Number<k0 * KGroup + kg0>{},
                                                   I0,
                                                   I0),
                                        a_block_buf.At(local_read_buf),
                                        a_thread_desc_,
                                        make_tuple(
                                            Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
                                                   2>{},
                                            I0,
                                            I0,
                                            k0,
                                            I0,
                                            Number<kg0 * A_K1>{}),
                                        a_thread_buf);
                                });
                            });
                        }
                        else
                        {
                            static_for<0, KRepeat, 1>{}([&](auto k0) {
                                static_for<0, KGroup, 1>{}([&](auto kg0) {
                                    a_thread_copy_.Run(
                                        a_block_desc_m0_m1_m2_k0_k1_k2,
                                        make_tuple(Number<(m0 + 2) % MRepeat>{},
                                                   I0,
                                                   I0,
                                                   Number<k0 * KGroup + kg0>{},
                                                   I0,
                                                   I0),
                                        a_block_buf.At(mfma_reg_buf),
                                        a_thread_desc_,
                                        make_tuple(
                                            Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
                                                   2>{},
                                            I0,
                                            I0,
                                            k0,
                                            I0,
                                            Number<kg0 * A_K1>{}),
                                        a_thread_buf);
                                });
                            });
                        }
                    });
                    HotLoopScheduler();
                };

                LoopFunc(I0, I1);
                LoopFunc(I1, I0);

                i += 2;
            } while(i < (num_loop - 2));
        }
        // tail
        if constexpr(TailNum == TailNumber::Even)
        {
            b_blockwise_copy.Run(b_grid_desc,
                                 b_grid_buf,
                                 b_block_desc_n0_n1_k0_k1,
                                 b_block_origin_idx,
                                 b_thread_bufs(I1));

            b_blockwise_copy_up.Run(b_grid_desc,
                                    b_grid_buf_up,
                                    b_block_desc_n0_n1_k0_k1,
                                    b_block_origin_idx,
                                    b_thread_bufs_up(I1));
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec_up;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];

                            b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec_up.template AsType<mfma_input_type>(),
                                        c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
                if constexpr(m0.value == (MRepeat - 2))
                {
                    block_sync_lds();

                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        static_for<0, KGroup, 1>{}([&](auto kg0) {
                            a_thread_copy_.Run(
                                a_block_desc_m0_m1_m2_k0_k1_k2,
                                make_tuple(Number<(m0 + 2) % MRepeat>{},
                                           I0,
                                           I0,
                                           Number<k0 * KGroup + kg0>{},
                                           I0,
                                           I0),
                                a_block_buf.At(I1),
                                a_thread_desc_,
                                make_tuple(
                                    Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
                                a_thread_buf);
                        });
                    });
                }
                else if constexpr(m0.value == MRepeat - 1)
                {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        static_for<0, KGroup, 1>{}([&](auto kg0) {
                            a_thread_copy_.Run(
                                a_block_desc_m0_m1_m2_k0_k1_k2,
                                make_tuple(Number<(m0 + 2) % MRepeat>{},
                                           I0,
                                           I0,
                                           Number<k0 * KGroup + kg0>{},
                                           I0,
                                           I0),
                                a_block_buf.At(I1),
                                a_thread_desc_,
                                make_tuple(
                                    Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
                                a_thread_buf);
                        });
                    });
                }
                else
                {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        static_for<0, KGroup, 1>{}([&](auto kg0) {
                            a_thread_copy_.Run(
                                a_block_desc_m0_m1_m2_k0_k1_k2,
                                make_tuple(Number<(m0 + 2) % MRepeat>{},
                                           I0,
                                           I0,
                                           Number<k0 * KGroup + kg0>{},
                                           I0,
                                           I0),
                                a_block_buf.At(I0),
                                a_thread_desc_,
                                make_tuple(
                                    Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
                                a_thread_buf);
                        });
                    });
                }
            });

            HotLoopScheduler();

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec_up;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
                                    (m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                            b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec_up.template AsType<mfma_input_type>(),
                                        c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });

                if constexpr(m0.value < (MRepeat - 2))
                {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        static_for<0, KGroup, 1>{}([&](auto kg0) {
                            a_thread_copy_.Run(
                                a_block_desc_m0_m1_m2_k0_k1_k2,
                                make_tuple(
                                    Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
                                a_block_buf.At(I1),
                                a_thread_desc_,
                                make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
                                           I0,
                                           I0,
                                           k0,
                                           I0,
                                           Number<kg0 * A_K1>{}),
                                a_thread_buf);
                        });
                    });
                }
            });

            HotLoopScheduler();
            // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
            // latency
        }
        else if constexpr(TailNum == TailNumber::Odd)
        {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec_up;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                            b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec_up.template AsType<mfma_input_type>(),
                                        c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });

                if constexpr(m0.value < (MRepeat - 2))
                {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        static_for<0, KGroup, 1>{}([&](auto kg0) {
                            a_thread_copy_.Run(
                                a_block_desc_m0_m1_m2_k0_k1_k2,
                                make_tuple(
                                    Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
                                a_block_buf.At(I0),
                                a_thread_desc_,
                                make_tuple(
                                    Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
                                a_thread_buf);
                        });
                    });
                }
            });
        }
    }

    protected:
    // MRepeat MWave MLane KRepeat KLane KPack
    // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
    // Reduce the vgpr usage here.
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(I2, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
                                                         ComputeDataType,
                                                         decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
                                                         Sequence<0, 1, 2, 3, 4, 5>,
                                                         5,
                                                         A_K1,
                                                         A_K1>;

    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};

    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));

    static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;

    using Base::c_thread_desc_;
};

} // namespace ck
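Both schedulers in this file reduce to one idiom: __builtin_amdgcn_sched_group_barrier(mask, count, sync_id) appends count instructions of a given class to an issue template, with the mask values used per the inline comments (0x008 MFMA, 0x020 VMEM read, 0x100 DS read, 0x200 DS write), and __builtin_amdgcn_sched_barrier(0) sealing the template. A minimal device-side sketch with made-up counts, not the pipeline's actual schedule:

    // Sketch only: ask the scheduler to interleave one VMEM read and one
    // DS read into every four MFMAs, repeated for four scheduling groups.
    __device__ void toy_scheduler()
    {
    #pragma unroll
        for(int s = 0; s < 4; ++s)
        {
            __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // 4 MFMA
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // 1 VMEM read
            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // 1 DS read
        }
        __builtin_amdgcn_sched_barrier(0); // nothing may be reordered across this
    }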
@@ -0,0 +1,919 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"

namespace ck {

// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1

template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t ThreadBlockSize,
          index_t ScaleBlockSize,
          typename ADataType,
          typename AScaleDataType,
          typename BDataType,
          typename BScaleDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat, // MXdlPerWave
          index_t NRepeat, // NXdlPerWave
          index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1
{
};

template <index_t ThreadBlockSize,
          index_t ScaleBlockSize,
          typename ADataType,
          typename AScaleDataType,
          typename BDataType,
          typename BScaleDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat, // MXdlPerWave
          index_t NRepeat, // NXdlPerWave
          index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1<
    BlockGemmPipelineScheduler::Intrawave,
    ThreadBlockSize,
    ScaleBlockSize,
    ADataType,
    AScaleDataType,
    BDataType,
    BScaleDataType,
    ATileDesc,
    BTileDesc,
    AMmaTileDesc,
    BMmaTileDesc,
    ABlockTransferSrcScalarPerVector,
    BBlockTransferSrcScalarPerVector,
    MPerBlock,
    NPerBlock,
    KPerBlock,
    MPerXDL,
    NPerXDL,
    MRepeat,
    NRepeat,
    KPack> : BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
                                                  ADataType,
                                                  BDataType,
                                                  ATileDesc,
                                                  BTileDesc,
                                                  AMmaTileDesc,
                                                  BMmaTileDesc,
                                                  ABlockTransferSrcScalarPerVector,
                                                  BBlockTransferSrcScalarPerVector,
                                                  MPerBlock,
                                                  NPerBlock,
                                                  KPerBlock,
                                                  MPerXDL,
                                                  NPerXDL,
                                                  MRepeat,
                                                  NRepeat,
                                                  KPack>
{
    using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
                                                      ADataType,
                                                      BDataType,
                                                      ATileDesc,
                                                      BTileDesc,
                                                      AMmaTileDesc,
                                                      BMmaTileDesc,
                                                      ABlockTransferSrcScalarPerVector,
                                                      BBlockTransferSrcScalarPerVector,
                                                      MPerBlock,
                                                      NPerBlock,
                                                      KPerBlock,
                                                      MPerXDL,
                                                      NPerXDL,
                                                      MRepeat,
                                                      NRepeat,
                                                      KPack>;
    using Base::I0;
    using Base::I1;
    using Base::KRepeat;
    using Base::MWaves;
    using Base::NWaves;
    using Base::WaveSize;
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetWaveIdx;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;
    using Base::KThreadChunk;

    using Base::APackedSize;
    using Base::BPackedSize;
    using Base::ComputePackedSize;

    using AccType      = typename Base::AccType;
    using Tuple4       = typename Base::Tuple4;
    using ComputeTypeA = typename Base::ComputeTypeA;
    using ComputeTypeB = typename Base::ComputeTypeB;

    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 2;

    template <typename TileDesc_M0_M1_M2_K>
    __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
    {
        constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
        constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
        constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
        constexpr index_t K2 = KPack;
        constexpr index_t K1 = 64 / NPerXDL;
        constexpr index_t K0 = KRepeat;

        return transform_tensor_descriptor(
            TileDesc_M0_M1_M2_K{},
            make_tuple(
                make_pass_through_transform(Number<M0>{}),
                make_pass_through_transform(Number<M1>{}),
                make_pass_through_transform(Number<M2>{}),
                make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
    }
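
For orientation, here is a standalone sketch of the K-split this descriptor transform performs; the tile sizes below are assumed example values, not parameters taken from this diff:

    #include <cstdio>

    // Standalone sketch (not CK code): mirrors the K0/K1/K2 split performed by
    // MakeAGemmMmaTileDescriptor above. All numbers are assumed for the example.
    int main()
    {
        constexpr int NPerXDL = 16; // assumed
        constexpr int KRepeat = 4;  // assumed
        constexpr int KPack   = 32; // assumed

        constexpr int K2 = KPack;        // fastest-moving K sub-dimension
        constexpr int K1 = 64 / NPerXDL; // lanes of one wave along K
        constexpr int K0 = KRepeat;      // outer K iterations

        // The unmerge transform turns a K of length K0*K1*K2 into (K0, K1, K2),
        // so the (m0, m1, m2, k) tile becomes (m0, m1, m2, k0, k1, k2).
        std::printf("K = %d -> (K0, K1, K2) = (%d, %d, %d)\n", K0 * K1 * K2, K0, K1, K2);
    }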

    static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
        MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);

    static constexpr auto ScalesPerKBlockSize =
        KPerBlock / ScaleBlockSize; // How many mx-vectors per K block

    //> How many mx-vectors in each row/col are processed in one call to xdlops_gemm.Run()
    static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;

    //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
    static constexpr auto ScalesPerXdlopsRunPerThread =
        ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
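
A quick, hedged sanity check of these three constants, using assumed MFMA and tile parameters (not values asserted by this change):

    #include <cstdio>

    // Back-of-the-envelope check of the scale bookkeeping above. One e8m0 scale
    // covers ScaleBlockSize elements along K. All parameters are assumed.
    int main()
    {
        constexpr int KPerBlock      = 128; // assumed
        constexpr int ScaleBlockSize = 32;  // assumed (one mx-vector)
        constexpr int KPack          = 32;  // assumed
        constexpr int K0PerXdlops    = 4;   // assumed mfma property
        constexpr int num_input_blks = 4;   // assumed mfma property

        constexpr int ScalesPerKBlockSize = KPerBlock / ScaleBlockSize;             // 4
        constexpr int ScalesPerXdlopsRun  = (KPack * K0PerXdlops) / ScaleBlockSize; // 4
        constexpr int PerThread           = ScalesPerXdlopsRun / num_input_blks;    // 1

        std::printf("scales per K block: %d, per xdlops run: %d, per thread: %d\n",
                    ScalesPerKBlockSize, ScalesPerXdlopsRun, PerThread);
    }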

    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
    }
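
A minimal host-side sketch of how these two predicates classify a loop count (the TailNumber enum below is a stand-in for ck::TailNumber, and PrefetchStages is 2 in this pipeline):

    #include <cstdio>

    enum class TailNumber { Odd, Even };

    constexpr int PrefetchStages = 2;

    constexpr bool BlockHasHotloop(int num_loop) { return num_loop > PrefetchStages; }
    constexpr TailNumber BlockLoopTailNum(int num_loop)
    {
        return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
    }

    int main()
    {
        for(int num_loop : {1, 2, 3, 8})
        {
            std::printf("num_loop=%d hotloop=%d tail=%s\n",
                        num_loop,
                        BlockHasHotloop(num_loop),
                        BlockLoopTailNum(num_loop) == TailNumber::Even ? "Even" : "Odd");
        }
    }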

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename AScaleGridBuffer,
              typename AScaleGridDesc,
              typename AScaleThreadTransfer,
              typename BScaleGridBuffer,
              typename BScaleGridDesc,
              typename BScaleThreadTransfer>
    __device__ void Run(
        // ABlockCopy
        const AGridDesc& a_grid_desc,
        const ABlockDesc& a_block_desc,
        ABlockTransfer& a_blockwise_copy,
        const AGridBuffer& a_grid_buf,
        ABlockBuffer& a_block_buf,
        const ABlockTransferStep& a_block_copy_step,
        // BBlockCopy
        const BGridDesc& b_grid_desc,
        const BBlockDesc& b_block_desc,
        BBlockTransfer& b_blockwise_copy,
        BBlockTransfer& b_blockwise_copy_up,
        const BGridBuffer& b_grid_buf,
        const BGridBuffer& b_grid_buf_up,
        BBlockBuffer& b_block_buf,
        const BBlockTransferStep& b_block_copy_step,
        // CThread
        CThreadBuffer& c_thread_buf,
        CThreadBuffer& c_thread_buf_up,
        // A and B scales
        const AScaleGridDesc& a_scale_grid_desc,
        AScaleThreadTransfer& a_scale_thread_copy,
        const AScaleGridBuffer& a_scale_grid_buf,
        const BScaleGridDesc& b_scale_grid_desc,
        BScaleThreadTransfer& b_scale_thread_copy,
        BScaleThreadTransfer& b_scale_thread_copy_up,
        const BScaleGridBuffer& b_scale_grid_buf,
        const BScaleGridBuffer& b_scale_grid_buf_up,
        index_t num_loop) const
    {
        ignore = b_block_desc;
        ignore = b_block_buf;
        ignore = a_scale_grid_buf;
        ignore = b_scale_grid_buf;
        ignore = b_scale_grid_buf_up;
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
        StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);

        auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
            a_scale_thread_desc.GetElementSpaceSize());
        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
            b_scale_thread_desc.GetElementSpaceSize());

        StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
        StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
        StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs_up;

        // Global prefetch A1 B1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.Run(b_grid_desc,
                             b_grid_buf,
                             b_block_desc_n0_n1_k0_k1,
                             b_block_origin_idx,
                             b_thread_bufs(I0));
        b_blockwise_copy_up.Run(b_grid_desc,
                                b_grid_buf_up,
                                b_block_desc_n0_n1_k0_k1,
                                b_block_origin_idx,
                                b_thread_bufs_up(I0));

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Prefetch a_scales to buf 0
        a_scale_thread_copy.Run(a_scale_grid_desc,
                                a_scale_grid_buf,
                                a_scale_thread_desc,
                                make_tuple(I0, I0, I0),
                                a_scale_thread_bufs(I0));

        // advance to the next set of scales along K
        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                               make_multi_index(0, ScalesPerKBlockSize, 0));

        // Prefetch b_scales to buf 0
        static_for<0, NRepeat, 1>{}([&](auto n0) {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
                    constexpr auto b_scale_offset =
                        b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
                    auto b_scale_thread_buf_copy =
                        make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
                            b_scale_thread_desc_copy.GetElementSpaceSize());
                    b_scale_thread_copy.Run(b_scale_grid_desc,
                                            b_scale_grid_buf,
                                            b_scale_thread_desc_copy,
                                            make_tuple(I0, I0),
                                            b_scale_thread_buf_copy);

                    b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
                        b_scale_thread_buf_copy[Number<0>{}];
                    b_scale_thread_copy.MoveSrcSliceWindow(
                        b_scale_grid_desc,
                        make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));

                    auto b_scale_thread_buf_copy_up =
                        make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
                            b_scale_thread_desc_copy.GetElementSpaceSize());
                    b_scale_thread_copy_up.Run(b_scale_grid_desc,
                                               b_scale_grid_buf_up,
                                               b_scale_thread_desc_copy,
                                               make_tuple(I0, I0),
                                               b_scale_thread_buf_copy_up);

                    b_scale_thread_bufs_up(I0)(Number<b_scale_offset>{}) =
                        b_scale_thread_buf_copy_up[Number<0>{}];
                    b_scale_thread_copy_up.MoveSrcSliceWindow(
                        b_scale_grid_desc,
                        make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
                });
            });
            b_scale_thread_copy.MoveSrcSliceWindow(
                b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
            b_scale_thread_copy_up.MoveSrcSliceWindow(
                b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
        });

        // restore col id and advance to the next set of scales
        // NWaves * NPerXDL * NRepeat == NPerBlock
        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                               make_multi_index(-NPerBlock, ScalesPerKBlockSize));
        b_scale_thread_copy_up.MoveSrcSliceWindow(
            b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));

        __builtin_amdgcn_sched_barrier(0);

        // Local prefill A1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);

        // Global prefetch A2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

        // Prefetch a_scales to buf 1
        a_scale_thread_copy.Run(a_scale_grid_desc,
                                a_scale_grid_buf,
                                a_scale_thread_desc,
                                make_tuple(I0, I0, I0),
                                a_scale_thread_bufs(I1));

        // advance to the next set of scales along K
        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                               make_multi_index(0, ScalesPerKBlockSize, 0));

        // Prefetch b_scales to buf 1
        static_for<0, NRepeat, 1>{}([&](auto n0) {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
                    constexpr auto b_scale_offset =
                        b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
                    auto b_scale_thread_buf_copy =
                        make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
                            b_scale_thread_desc_copy.GetElementSpaceSize());
                    b_scale_thread_copy.Run(b_scale_grid_desc,
                                            b_scale_grid_buf,
                                            b_scale_thread_desc_copy,
                                            make_tuple(I0, I0),
                                            b_scale_thread_buf_copy);

                    b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
                        b_scale_thread_buf_copy[Number<0>{}];
                    b_scale_thread_copy.MoveSrcSliceWindow(
                        b_scale_grid_desc,
                        make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));

                    auto b_scale_thread_buf_copy_up =
                        make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
                            b_scale_thread_desc_copy.GetElementSpaceSize());
                    b_scale_thread_copy_up.Run(b_scale_grid_desc,
                                               b_scale_grid_buf_up,
                                               b_scale_thread_desc_copy,
                                               make_tuple(I0, I0),
                                               b_scale_thread_buf_copy_up);

                    b_scale_thread_bufs_up(I1)(Number<b_scale_offset>{}) =
                        b_scale_thread_buf_copy_up[Number<0>{}];
                    b_scale_thread_copy_up.MoveSrcSliceWindow(
                        b_scale_grid_desc,
                        make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
                });
            });
            b_scale_thread_copy.MoveSrcSliceWindow(
                b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
            b_scale_thread_copy_up.MoveSrcSliceWindow(
                b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
        });

        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                               make_multi_index(-NPerBlock, ScalesPerKBlockSize));
        b_scale_thread_copy_up.MoveSrcSliceWindow(
            b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));

        // Local prefetch A1
        block_sync_lds();
        static_for<0, KRepeat, 1>{}([&](auto k) {
            constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
                    constexpr auto a_k_step_chunk =
                        k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
                                       a_block_buf,
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
                                       a_thread_buf);
                });
            });
        });

        // Initialize C
        c_thread_buf.Clear();
        c_thread_buf_up.Clear();

        // main body
        if constexpr(HasMainLoop)
        {
            // loop over k with the step KPerBlock
            index_t i = 0;
            do
            {
                auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
                    b_blockwise_copy.Run(b_grid_desc,
                                         b_grid_buf,
                                         b_block_desc_n0_n1_k0_k1,
                                         b_block_origin_idx,
                                         b_thread_bufs(local_read_buf));
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                    b_blockwise_copy_up.Run(b_grid_desc,
                                            b_grid_buf_up,
                                            b_block_desc_n0_n1_k0_k1,
                                            b_block_origin_idx,
                                            b_thread_bufs_up(local_read_buf));
                    b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                    block_sync_lds();
                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);

                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            static_for<0, KRepeat, 1>{}([&](auto k0) {
                                vector_type<ComputeTypeA, KPack> a_thread_vec;
                                vector_type<ComputeTypeB, KPack> b_thread_vec;
                                vector_type<ComputeTypeB, KPack> b_thread_vec_up;

                                static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                            make_tuple(m0, I0, k0, ik))>{}];
                                    b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                        b_thread_bufs[mfma_reg_buf]
                                                     [Number<b_thread_desc_.CalculateOffset(
                                                         make_tuple(n0, I0, k0, ik))>{}];
                                    b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
                                        b_thread_bufs_up[mfma_reg_buf]
                                                        [Number<b_thread_desc_.CalculateOffset(
                                                            make_tuple(n0, I0, k0, ik))>{}];
                                });

                                constexpr index_t a_scale_offset =
                                    a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
                                constexpr index_t b_scale_offset =
                                    b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));

                                static_assert(0 < ScalesPerXdlopsRunPerThread,
                                              "Must have at least one scale per Xdlops per Thread.");

                                vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
                                    a_scale_thread_vec;
                                vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
                                    b_scale_thread_vec;
                                vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
                                    b_scale_thread_vec_up;

                                // Pack scale_thread_buf into scale_thread_vec
                                static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
                                    a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
                                        a_scale_thread_bufs[mfma_reg_buf]
                                                           [Number<a_scale_offset + s>{}];
                                    b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
                                        b_scale_thread_bufs[mfma_reg_buf]
                                                           [Number<b_scale_offset + s>{}];
                                    b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
                                        b_scale_thread_bufs_up[mfma_reg_buf]
                                                              [Number<b_scale_offset + s>{}];
                                });

                                using mfma_input_type_a =
                                    typename vector_type<ComputeTypeA,
                                                         xdlops_gemm.K1PerXdlops /
                                                             APackedSize>::type;
                                using mfma_input_type_b =
                                    typename vector_type<ComputeTypeB,
                                                         xdlops_gemm.K1PerXdlops /
                                                             BPackedSize>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                // MFMA accumulation
                                xdlops_gemm.template Run<>(
                                    a_thread_vec.template AsType<mfma_input_type_a>(),
                                    a_scale_thread_vec.template AsType<AScaleDataType>(),
                                    b_thread_vec.template AsType<mfma_input_type_b>(),
                                    b_scale_thread_vec.template AsType<BScaleDataType>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                                xdlops_gemm.template Run<>(
                                    a_thread_vec.template AsType<mfma_input_type_a>(),
                                    a_scale_thread_vec.template AsType<AScaleDataType>(),
                                    b_thread_vec_up.template AsType<mfma_input_type_b>(),
                                    b_scale_thread_vec_up.template AsType<BScaleDataType>(),
                                    c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
                            });
                        });
                    });

                    block_sync_lds();

                    // a thread copy
                    static_for<0, KRepeat, 1>{}([&](auto k) {
                        constexpr auto k_step =
                            k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);

                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}(
                                [&](auto chunk) {
                                    constexpr auto a_k_step_chunk =
                                        k_step + chunk * KThreadChunk *
                                                     xdlops_gemm.mfma_instr.num_input_blks;
                                    a_thread_copy_.Run(
                                        a_block_desc_m0_m1_m2_k,
                                        make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
                                        a_block_buf,
                                        a_thread_desc_,
                                        make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
                                        a_thread_buf);
                                });
                        });
                    });

                    // Prefetch a_scales
                    a_scale_thread_copy.Run(a_scale_grid_desc,
                                            a_scale_grid_buf,
                                            a_scale_thread_desc,
                                            make_tuple(I0, I0, I0),
                                            a_scale_thread_bufs(mfma_reg_buf));

                    // advance to the next set of scales along K
                    a_scale_thread_copy.MoveSrcSliceWindow(
                        a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0));

                    // Prefetch b_scales
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        static_for<0, KRepeat, 1>{}([&](auto k0) {
                            static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
                                constexpr auto b_scale_offset =
                                    b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
                                auto b_scale_thread_buf_copy =
                                    make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
                                        b_scale_thread_desc_copy.GetElementSpaceSize());
                                b_scale_thread_copy.Run(b_scale_grid_desc,
                                                        b_scale_grid_buf,
                                                        b_scale_thread_desc_copy,
                                                        make_tuple(I0, I0),
                                                        b_scale_thread_buf_copy);

                                b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
                                    b_scale_thread_buf_copy[Number<0>{}];
                                b_scale_thread_copy.MoveSrcSliceWindow(
                                    b_scale_grid_desc,
                                    make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));

                                auto b_scale_thread_buf_copy_up =
                                    make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
                                        b_scale_thread_desc_copy.GetElementSpaceSize());
                                b_scale_thread_copy_up.Run(b_scale_grid_desc,
                                                           b_scale_grid_buf_up,
                                                           b_scale_thread_desc_copy,
                                                           make_tuple(I0, I0),
                                                           b_scale_thread_buf_copy_up);

                                b_scale_thread_bufs_up(mfma_reg_buf)(Number<b_scale_offset>{}) =
                                    b_scale_thread_buf_copy_up[Number<0>{}];
                                b_scale_thread_copy_up.MoveSrcSliceWindow(
                                    b_scale_grid_desc,
                                    make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
                            });
                        });
                        b_scale_thread_copy.MoveSrcSliceWindow(
                            b_scale_grid_desc,
                            make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
                        b_scale_thread_copy_up.MoveSrcSliceWindow(
                            b_scale_grid_desc,
                            make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
                    });

                    b_scale_thread_copy.MoveSrcSliceWindow(
                        b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
                    b_scale_thread_copy_up.MoveSrcSliceWindow(
                        b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
                };

                LoopFunc(I0, I1);
                LoopFunc(I1, I0);

                i += 2;
            } while(i < (num_loop - 2));
        }
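
The do/while above retires two K blocks per trip, flip-flopping the register buffers. A host-side sketch of the same ping-pong schedule, with placeholder prints standing in for the MFMA and copy work:

    #include <cstdio>

    // Each iteration consumes the buffer filled in the previous stage while
    // prefetching into the other one. num_loop is an assumed example value.
    int main()
    {
        constexpr int num_loop = 8; // assumed even and > PrefetchStages (2)
        int i = 0;
        do
        {
            // LoopFunc(I0, I1): compute from buf 0, prefetch into buf 1
            std::printf("iter %d: mfma from buf 0, load into buf 1\n", i);
            // LoopFunc(I1, I0): compute from buf 1, prefetch into buf 0
            std::printf("iter %d: mfma from buf 1, load into buf 0\n", i + 1);
            i += 2;
        } while(i < (num_loop - 2)); // the last two K blocks are handled by the tail
    }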

        // tail
        if constexpr(TailNum == TailNumber::Even)
        {
            b_blockwise_copy.Run(b_grid_desc,
                                 b_grid_buf,
                                 b_block_desc_n0_n1_k0_k1,
                                 b_block_origin_idx,
                                 b_thread_bufs(I1));

            b_blockwise_copy_up.Run(b_grid_desc,
                                    b_grid_buf_up,
                                    b_block_desc_n0_n1_k0_k1,
                                    b_block_origin_idx,
                                    b_thread_bufs_up(I1));
            block_sync_lds();
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        vector_type<ComputeTypeA, KPack> a_thread_vec;
                        vector_type<ComputeTypeB, KPack> b_thread_vec;
                        vector_type<ComputeTypeB, KPack> b_thread_vec_up;

                        static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                            b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
                                b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        constexpr index_t a_scale_offset =
                            a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));

                        constexpr index_t b_scale_offset =
                            b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));

                        vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
                        vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
                        vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
                            b_scale_thread_vec_up;

                        // Pack scale_thread_bufs into scale_thread_vecs
                        static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
                            a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
                                a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
                            b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
                                b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
                            b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
                                b_scale_thread_bufs_up[I0][Number<b_scale_offset + s>{}];
                        });

                        using mfma_input_type_a =
                            typename vector_type<ComputeTypeA,
                                                 xdlops_gemm.K1PerXdlops / APackedSize>::type;
                        using mfma_input_type_b =
                            typename vector_type<ComputeTypeB,
                                                 xdlops_gemm.K1PerXdlops / BPackedSize>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        // MFMA accumulation
                        xdlops_gemm.template Run<>(
                            a_thread_vec.template AsType<mfma_input_type_a>(),
                            a_scale_thread_vec.template AsType<AScaleDataType>(),
                            b_thread_vec.template AsType<mfma_input_type_b>(),
                            b_scale_thread_vec.template AsType<BScaleDataType>(),
                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        xdlops_gemm.template Run<>(
                            a_thread_vec.template AsType<mfma_input_type_a>(),
                            a_scale_thread_vec.template AsType<AScaleDataType>(),
                            b_thread_vec_up.template AsType<mfma_input_type_b>(),
                            b_scale_thread_vec_up.template AsType<BScaleDataType>(),
                            c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            block_sync_lds();

            // a thread copy
            static_for<0, KRepeat, 1>{}([&](auto k) {
                constexpr auto k_step =
                    k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);

                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
                        constexpr auto a_k_step_chunk =
                            k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
                                           a_block_buf,
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
                                           a_thread_buf);
                    });
                });
            });

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        vector_type<ComputeTypeA, KPack> a_thread_vec;
                        vector_type<ComputeTypeB, KPack> b_thread_vec;
                        vector_type<ComputeTypeB, KPack> b_thread_vec_up;

                        static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                            b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
                                b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        constexpr index_t a_scale_offset =
                            a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));

                        constexpr index_t b_scale_offset =
                            b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));

                        vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
                        vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
                        vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
                            b_scale_thread_vec_up;

                        // Pack scale_thread_bufs into scale_thread_vecs
                        static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
                            a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
                                a_scale_thread_bufs[I1][Number<a_scale_offset + s>{}];
                            b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
                                b_scale_thread_bufs[I1][Number<b_scale_offset + s>{}];
                            b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
                                b_scale_thread_bufs_up[I1][Number<b_scale_offset + s>{}];
                        });

                        using mfma_input_type_a =
                            typename vector_type<ComputeTypeA,
                                                 xdlops_gemm.K1PerXdlops / APackedSize>::type;
                        using mfma_input_type_b =
                            typename vector_type<ComputeTypeB,
                                                 xdlops_gemm.K1PerXdlops / BPackedSize>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        // MFMA accumulation
                        xdlops_gemm.template Run<>(
                            a_thread_vec.template AsType<mfma_input_type_a>(),
                            a_scale_thread_vec.template AsType<AScaleDataType>(),
                            b_thread_vec.template AsType<mfma_input_type_b>(),
                            b_scale_thread_vec.template AsType<BScaleDataType>(),
                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        xdlops_gemm.template Run<>(
                            a_thread_vec.template AsType<mfma_input_type_a>(),
                            a_scale_thread_vec.template AsType<AScaleDataType>(),
                            b_thread_vec_up.template AsType<mfma_input_type_b>(),
                            b_scale_thread_vec_up.template AsType<BScaleDataType>(),
                            c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        }
        else if constexpr(TailNum == TailNumber::Odd)
        {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        vector_type<ComputeTypeA, KPack> a_thread_vec;
                        vector_type<ComputeTypeB, KPack> b_thread_vec;
                        vector_type<ComputeTypeB, KPack> b_thread_vec_up;

                        static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                            b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
                                b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        constexpr index_t a_scale_offset =
                            a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));

                        constexpr index_t b_scale_offset =
                            b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));

                        vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
                        vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
                        vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
                            b_scale_thread_vec_up;

                        // Pack scale_thread_bufs into scale_thread_vecs
                        static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
                            a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
                                a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
                            b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
                                b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
                            b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
                                b_scale_thread_bufs_up[I0][Number<b_scale_offset + s>{}];
                        });

                        using mfma_input_type_a =
                            typename vector_type<ComputeTypeA,
                                                 xdlops_gemm.K1PerXdlops / APackedSize>::type;
                        using mfma_input_type_b =
                            typename vector_type<ComputeTypeB,
                                                 xdlops_gemm.K1PerXdlops / BPackedSize>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        // MFMA accumulation
                        xdlops_gemm.template Run<>(
                            a_thread_vec.template AsType<mfma_input_type_a>(),
                            a_scale_thread_vec.template AsType<AScaleDataType>(),
                            b_thread_vec.template AsType<mfma_input_type_b>(),
                            b_scale_thread_vec.template AsType<BScaleDataType>(),
                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        xdlops_gemm.template Run<>(
                            a_thread_vec.template AsType<mfma_input_type_a>(),
                            a_scale_thread_vec.template AsType<AScaleDataType>(),
                            b_thread_vec_up.template AsType<mfma_input_type_b>(),
                            b_scale_thread_vec_up.template AsType<BScaleDataType>(),
                            c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        }
    }

    // TODO: make this field protected when a_scale_thread_copy_ is moved here
    static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));

    // Is used to copy data from a_scale_grid to a_scale_thread
    static constexpr auto a_scale_thread_desc_copy =
        make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));

    // TODO: make this field protected when b_scale_thread_copy_ is moved here
    static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
        make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));

    // Is used to copy data from b_scale_grid to b_scale_thread_buf
    static constexpr auto b_scale_thread_desc_copy =
        make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));

    protected:
    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    // using Base::b_thread_desc_;
    using Base::c_thread_desc_;

    static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
};

} // namespace ck

File diff suppressed because it is too large
@@ -0,0 +1,155 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp"

namespace ck {
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
          BlockGemmPipelineScheduler BlkGemmPipeSche,
          index_t ThreadBlockSize,
          index_t ScaleBlockSize,
          typename ADataType,
          typename AScaleDataType,
          typename BDataType,
          typename BScaleDataType,
          typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
                                    // must be used for compute
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          bool GUFusion = false>
constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
{

    // Hardware MX GEMM pipeline
    if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
    {
        if constexpr(GUFusion)
        {
            return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1<
                BlkGemmPipeSche,
                ThreadBlockSize,
                ScaleBlockSize,
                ADataType,
                AScaleDataType,
                BDataType,
                BScaleDataType,
                ATileDesc,
                BTileDesc,
                AMmaTileDesc,
                BMmaTileDesc,
                ABlockTransferSrcScalarPerVector,
                BBlockTransferSrcScalarPerVector,
                MPerBlock,
                NPerBlock,
                KPerBlock,
                MPerXDL,
                NPerXDL,
                MRepeat,
                NRepeat,
                KPack>{};
        }
        else
        {
            return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<
                BlkGemmPipeSche,
                ThreadBlockSize,
                ScaleBlockSize,
                ADataType,
                AScaleDataType,
                BDataType,
                BScaleDataType,
                ATileDesc,
                BTileDesc,
                AMmaTileDesc,
                BMmaTileDesc,
                ABlockTransferSrcScalarPerVector,
                BBlockTransferSrcScalarPerVector,
                MPerBlock,
                NPerBlock,
                KPerBlock,
                MPerXDL,
                NPerXDL,
                MRepeat,
                NRepeat,
                KPack>{};
        }
    }
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
    {
        if constexpr(GUFusion)
        {
            return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
                BlkGemmPipeSche,
                ThreadBlockSize,
                ScaleBlockSize,
                ADataType,
                AScaleDataType,
                BDataType,
                BScaleDataType,
                ATileDesc,
                BTileDesc,
                AMmaTileDesc,
                BMmaTileDesc,
                ABlockTransferSrcScalarPerVector,
                BBlockTransferSrcScalarPerVector,
                MPerBlock,
                NPerBlock,
                KPerBlock,
                MPerXDL,
                NPerXDL,
                MRepeat,
                NRepeat,
                KPack>{};
        }
        else
        {
            return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<
                BlkGemmPipeSche,
                ThreadBlockSize,
                ScaleBlockSize,
                ADataType,
                AScaleDataType,
                BDataType,
                BScaleDataType,
                ATileDesc,
                BTileDesc,
                AMmaTileDesc,
                BMmaTileDesc,
                ABlockTransferSrcScalarPerVector,
                BBlockTransferSrcScalarPerVector,
                MPerBlock,
                NPerBlock,
                KPerBlock,
                MPerXDL,
                NPerXDL,
                MRepeat,
                NRepeat,
                KPack>{};
        }
    }
    else
    {
        std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
    }
}
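
Callers are expected to capture the selector's return type with decltype; a self-contained miniature of that idiom (PipelineV1, PipelineV3, and Version below are stand-ins for illustration, not CK types):

    #include <cstdio>
    #include <type_traits>

    struct PipelineV1 { static constexpr const char* name = "v1"; };
    struct PipelineV3 { static constexpr const char* name = "v3"; };

    enum class Version { v1, v3 };

    // Returns a different pipeline *type* per constexpr branch, just like the
    // selector above; only one branch is instantiated per specialization.
    template <Version Ver>
    constexpr auto Pipeline_Selector()
    {
        if constexpr(Ver == Version::v1) { return PipelineV1{}; }
        else                             { return PipelineV3{}; }
    }

    int main()
    {
        using Pipe = std::remove_cv_t<decltype(Pipeline_Selector<Version::v3>())>;
        std::printf("selected pipeline: %s\n", Pipe::name); // prints "v3"
    }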

} // namespace ck

@@ -0,0 +1,813 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"

namespace ck {

// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1

template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t ThreadBlockSize,
          index_t ScaleBlockSize,
          typename ADataType,
          typename AScaleDataType,
          typename BDataType,
          typename BScaleDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat, // MXdlPerWave
          index_t NRepeat, // NXdlPerWave
          index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1
{
};

template <index_t ThreadBlockSize,
          index_t ScaleBlockSize,
          typename ADataType,
          typename AScaleDataType,
          typename BDataType,
          typename BScaleDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat, // MXdlPerWave
          index_t NRepeat, // NXdlPerWave
          index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<
    BlockGemmPipelineScheduler::Intrawave,
    ThreadBlockSize,
    ScaleBlockSize,
    ADataType,
    AScaleDataType,
    BDataType,
    BScaleDataType,
    ATileDesc,
    BTileDesc,
    AMmaTileDesc,
    BMmaTileDesc,
    ABlockTransferSrcScalarPerVector,
    BBlockTransferSrcScalarPerVector,
    MPerBlock,
    NPerBlock,
    KPerBlock,
    MPerXDL,
    NPerXDL,
    MRepeat,
    NRepeat,
    KPack>
    : BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
                                           ADataType,
                                           BDataType,
                                           ATileDesc,
                                           BTileDesc,
                                           AMmaTileDesc,
                                           BMmaTileDesc,
                                           ABlockTransferSrcScalarPerVector,
                                           BBlockTransferSrcScalarPerVector,
                                           MPerBlock,
                                           NPerBlock,
                                           KPerBlock,
                                           MPerXDL,
                                           NPerXDL,
                                           MRepeat,
                                           NRepeat,
                                           KPack>
{
    using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
                                                      ADataType,
                                                      BDataType,
                                                      ATileDesc,
                                                      BTileDesc,
                                                      AMmaTileDesc,
                                                      BMmaTileDesc,
                                                      ABlockTransferSrcScalarPerVector,
                                                      BBlockTransferSrcScalarPerVector,
                                                      MPerBlock,
                                                      NPerBlock,
                                                      KPerBlock,
                                                      MPerXDL,
                                                      NPerXDL,
                                                      MRepeat,
                                                      NRepeat,
                                                      KPack>;
    using Base::I0;
    using Base::I1;
    using Base::KRepeat;
    using Base::MWaves;
    using Base::NWaves;
    using Base::WaveSize;
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetWaveIdx;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;
    using Base::KThreadChunk;

    using Base::APackedSize;
    using Base::BPackedSize;
    using Base::ComputePackedSize;

    using AccType      = typename Base::AccType;
    using Tuple4       = typename Base::Tuple4;
    using ComputeTypeA = typename Base::ComputeTypeA;
    using ComputeTypeB = typename Base::ComputeTypeB;

    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 2;

    template <typename TileDesc_M0_M1_M2_K>
    __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
    {
        constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
        constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
        constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
        constexpr index_t K2 = KPack;
        constexpr index_t K1 = 64 / NPerXDL;
        constexpr index_t K0 = KRepeat;

        return transform_tensor_descriptor(
            TileDesc_M0_M1_M2_K{},
            make_tuple(
                make_pass_through_transform(Number<M0>{}),
                make_pass_through_transform(Number<M1>{}),
                make_pass_through_transform(Number<M2>{}),
                make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
    }

    static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
        MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);

    static constexpr auto ScalesPerKBlockSize =
        KPerBlock / ScaleBlockSize; // How many mx-vectors per K block

    //> How many mx-vectors in each row/col are processed in one call to xdlops_gemm.Run()
    static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;

    //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
    static constexpr auto ScalesPerXdlopsRunPerThread =
        ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;

    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
    }

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename AScaleGridBuffer,
              typename AScaleGridDesc,
              typename AScaleThreadTransfer,
              typename BScaleGridBuffer,
              typename BScaleGridDesc,
              typename BScaleThreadTransfer>
    __device__ void Run(
        // ABlockCopy
        const AGridDesc& a_grid_desc,
        const ABlockDesc& a_block_desc,
        ABlockTransfer& a_blockwise_copy,
        const AGridBuffer& a_grid_buf,
        ABlockBuffer& a_block_buf,
        const ABlockTransferStep& a_block_copy_step,
        // BBlockCopy
        const BGridDesc& b_grid_desc,
        const BBlockDesc& b_block_desc,
        BBlockTransfer& b_blockwise_copy,
        const BGridBuffer& b_grid_buf,
        BBlockBuffer& b_block_buf,
        const BBlockTransferStep& b_block_copy_step,
        // CThread
        CThreadBuffer& c_thread_buf,
        // A and B scales
        const AScaleGridDesc& a_scale_grid_desc,
        AScaleThreadTransfer& a_scale_thread_copy,
        const AScaleGridBuffer& a_scale_grid_buf,
        const BScaleGridDesc& b_scale_grid_desc,
        BScaleThreadTransfer& b_scale_thread_copy,
        const BScaleGridBuffer& b_scale_grid_buf,
        index_t num_loop) const
    {
        ignore = b_block_desc;
        ignore = b_block_buf;

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);

        auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
            a_scale_thread_desc.GetElementSpaceSize());
        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
            b_scale_thread_desc.GetElementSpaceSize());

        StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
        StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;

        // Global prefetch A1 B1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.Run(b_grid_desc,
                             b_grid_buf,
                             b_block_desc_n0_n1_k0_k1,
                             b_block_origin_idx,
                             b_thread_bufs(I0));

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Prefetch a_scales
        static_for<0, MRepeat, 1>{}([&](auto m0) {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
                    constexpr auto a_scale_offset =
                        a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
                    auto a_scale_thread_buf_copy =
                        make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
                            a_scale_thread_desc_copy.GetElementSpaceSize());
                    a_scale_thread_copy.Run(a_scale_grid_desc,
                                            a_scale_grid_buf,
                                            a_scale_thread_desc_copy,
                                            make_tuple(I0, I0),
                                            a_scale_thread_buf_copy);

                    a_scale_thread_bufs(I0)(Number<a_scale_offset>{}) =
                        a_scale_thread_buf_copy[Number<0>{}];
                    a_scale_thread_copy.MoveSrcSliceWindow(
                        a_scale_grid_desc,
                        make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
                });
            });
            a_scale_thread_copy.MoveSrcSliceWindow(
                a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
        });

        // restore row id and advance to the next set of scales
        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                               make_multi_index(-MPerBlock, ScalesPerKBlockSize));

        // Prefetch b_scales to buf 0
        static_for<0, NRepeat, 1>{}([&](auto n0) {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
                    constexpr auto b_scale_offset =
                        b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
                    auto b_scale_thread_buf_copy =
                        make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
                            b_scale_thread_desc_copy.GetElementSpaceSize());
                    b_scale_thread_copy.Run(b_scale_grid_desc,
                                            b_scale_grid_buf,
                                            b_scale_thread_desc_copy,
                                            make_tuple(I0, I0),
                                            b_scale_thread_buf_copy);

                    b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
                        b_scale_thread_buf_copy[Number<0>{}];
                    b_scale_thread_copy.MoveSrcSliceWindow(
                        b_scale_grid_desc,
                        make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
                });
            });
            b_scale_thread_copy.MoveSrcSliceWindow(
                b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
        });

        // restore col id and advance to the next set of scales
        // NWaves * NPerXDL * NRepeat == NPerBlock
        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                               make_multi_index(-NPerBlock, ScalesPerKBlockSize));
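
A hedged arithmetic check of this window bookkeeping, with assumed tile numbers; it verifies that the walk returns the window to row 0 and advances one K block of scales:

    #include <cstdio>

    // Per n0 iteration: the inner k/s loop nets +ScalesPerKBlockSize along K,
    // and the (+NWaves*NPerXDL, -ScalesPerKBlockSize) move undoes that while
    // stepping rows. The final (-NPerBlock, +ScalesPerKBlockSize) move restores
    // the row origin because NWaves * NPerXDL * NRepeat == NPerBlock.
    int main()
    {
        constexpr int NWaves = 2, NPerXDL = 16, NRepeat = 4;        // assumed
        constexpr int NPerBlock           = NWaves * NPerXDL * NRepeat; // 128
        constexpr int ScalesPerKBlockSize = 4;                          // assumed

        int n = 0, k = 0; // window origin in the (N, K-scale) grid
        for(int n0 = 0; n0 < NRepeat; ++n0)
        {
            k += ScalesPerKBlockSize; // net effect of the inner k/s loop moves
            n += NWaves * NPerXDL;    // step to the next row group...
            k -= ScalesPerKBlockSize; // ...and rewind K for it
        }
        n -= NPerBlock;            // back to the block's first row
        k += ScalesPerKBlockSize;  // advance to the next K block of scales
        std::printf("after one K block: n=%d, k=%d\n", n, k); // prints n=0, k=4
    }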
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
|
||||
// Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Prefetch a_scales to buf 1
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
|
||||
auto a_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf_copy);
|
||||
|
||||
a_scale_thread_buf(I1)(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_copy[Number<0>{}];
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Prefetch b_scales to buf 1
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
        // Initialize C
        c_thread_buf.Clear();

        // main body
        if constexpr(HasMainLoop)
        {
            // loop over K with the step KPerBlock
            index_t i = 0;
            do
            {
                auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
                    b_blockwise_copy.Run(b_grid_desc,
                                         b_grid_buf,
                                         b_block_desc_n0_n1_k0_k1,
                                         b_block_origin_idx,
                                         b_thread_bufs(local_read_buf));
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                    block_sync_lds();
                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);

                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            static_for<0, KRepeat, 1>{}([&](auto k0) {
                                vector_type<ComputeTypeA, KPack> a_thread_vec;
                                vector_type<ComputeTypeB, KPack> b_thread_vec;

                                static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                            make_tuple(m0, I0, k0, ik))>{}];
                                    b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                        b_thread_bufs[mfma_reg_buf]
                                                     [Number<b_thread_desc_.CalculateOffset(
                                                         make_tuple(n0, I0, k0, ik))>{}];
                                });

                                constexpr index_t a_scale_offset =
                                    a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
                                constexpr index_t b_scale_offset =
                                    b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));

                                static_assert(
                                    0 < ScalesPerXdlopsRunPerThread,
                                    "Must have at least one scale per Xdlops per Thread.");

                                vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
                                    a_scale_thread_vec;
                                vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
                                    b_scale_thread_vec;

                                // Pack a/b scale_thread_bufs into scale_thread_vecs
                                static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
                                    a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
                                        a_scale_thread_bufs[mfma_reg_buf]
                                                           [Number<a_scale_offset + s>{}];
                                    b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
                                        b_scale_thread_bufs[mfma_reg_buf]
                                                           [Number<b_scale_offset + s>{}];
                                });

                                using mfma_input_type_a =
                                    typename vector_type<ComputeTypeA,
                                                         xdlops_gemm.K1PerXdlops /
                                                             APackedSize>::type;
                                using mfma_input_type_b =
                                    typename vector_type<ComputeTypeB,
                                                         xdlops_gemm.K1PerXdlops /
                                                             BPackedSize>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                // MFMA accumulation
                                xdlops_gemm.template Run<>(
                                    a_thread_vec.template AsType<mfma_input_type_a>(),
                                    a_scale_thread_vec.template AsType<AScaleDataType>(),
                                    b_thread_vec.template AsType<mfma_input_type_b>(),
                                    b_scale_thread_vec.template AsType<BScaleDataType>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            });
                        });
                    });

                    block_sync_lds();

                    // copy A from LDS into thread registers
                    static_for<0, KRepeat, 1>{}([&](auto k) {
                        constexpr auto k_step =
                            k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);

                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}(
                                [&](auto chunk) {
                                    constexpr auto a_k_step_chunk =
                                        k_step + chunk * KThreadChunk *
                                                     xdlops_gemm.mfma_instr.num_input_blks;
                                    a_thread_copy_.Run(
                                        a_block_desc_m0_m1_m2_k,
                                        make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
                                        a_block_buf,
                                        a_thread_desc_,
                                        make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
                                        a_thread_buf);
                                });
                        });
                    });

                    // Prefetch a_scales
                    a_scale_thread_copy.Run(a_scale_grid_desc,
                                            a_scale_grid_buf,
                                            a_scale_thread_desc,
                                            make_tuple(I0, I0, I0),
                                            a_scale_thread_bufs(mfma_reg_buf));

                    // restore row id and advance to the next set of scales
                    a_scale_thread_copy.MoveSrcSliceWindow(
                        a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0));

                    // Prefetch b_scales
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        static_for<0, KRepeat, 1>{}([&](auto k0) {
                            static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
                                constexpr auto b_scale_offset =
                                    b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
                                auto b_scale_thread_buf_copy =
                                    make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
                                        b_scale_thread_desc_copy.GetElementSpaceSize());
                                b_scale_thread_copy.Run(b_scale_grid_desc,
                                                        b_scale_grid_buf,
                                                        b_scale_thread_desc_copy,
                                                        make_tuple(I0, I0),
                                                        b_scale_thread_buf_copy);

                                b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
                                    b_scale_thread_buf_copy[Number<0>{}];
                                b_scale_thread_copy.MoveSrcSliceWindow(
                                    b_scale_grid_desc,
                                    make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
                            });
                        });
                        b_scale_thread_copy.MoveSrcSliceWindow(
                            b_scale_grid_desc,
                            make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
                    });

                    b_scale_thread_copy.MoveSrcSliceWindow(
                        b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
                };

                LoopFunc(I0, I1);
                LoopFunc(I1, I0);

                i += 2;
            } while(i < (num_loop - 2));
        }
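        // Editor's note (illustrative, not part of the original patch): each trip of
        // the do/while above retires two K-block iterations with ping-ponged register
        // buffers,
        //   LoopFunc(I0, I1);  // compute from buffer 0, global-read into buffer 1
        //   LoopFunc(I1, I0);  // compute from buffer 1, global-read into buffer 0
        // so for num_loop = 8 the main body covers six iterations and the remaining
        // two fall through to the TailNumber::Even path below; an odd num_loop ends
        // in the TailNumber::Odd path instead.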

        // tail
        if constexpr(TailNum == TailNumber::Even)
        {
            b_blockwise_copy.Run(b_grid_desc,
                                 b_grid_buf,
                                 b_block_desc_n0_n1_k0_k1,
                                 b_block_origin_idx,
                                 b_thread_bufs(I1));
            block_sync_lds();
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        vector_type<ComputeTypeA, KPack> a_thread_vec;
                        vector_type<ComputeTypeB, KPack> b_thread_vec;

                        static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        constexpr index_t a_scale_offset =
                            a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));

                        constexpr index_t b_scale_offset =
                            b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));

                        vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
                        vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;

                        // Pack a/b scale_thread_bufs into scale_thread_vecs
                        static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
                            a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
                                a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
                            b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
                                b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
                        });

                        using mfma_input_type_a =
                            typename vector_type<ComputeTypeA,
                                                 xdlops_gemm.K1PerXdlops / APackedSize>::type;
                        using mfma_input_type_b =
                            typename vector_type<ComputeTypeB,
                                                 xdlops_gemm.K1PerXdlops / BPackedSize>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        // MFMA accumulation
                        xdlops_gemm.template Run<>(
                            a_thread_vec.template AsType<mfma_input_type_a>(),
                            a_scale_thread_vec.template AsType<AScaleDataType>(),
                            b_thread_vec.template AsType<mfma_input_type_b>(),
                            b_scale_thread_vec.template AsType<BScaleDataType>(),
                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            block_sync_lds();

            // copy A from LDS into thread registers
            static_for<0, KRepeat, 1>{}([&](auto k) {
                constexpr auto k_step =
                    k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);

                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
                        constexpr auto a_k_step_chunk =
                            k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
                                           a_block_buf,
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
                                           a_thread_buf);
                    });
                });
            });

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        vector_type<ComputeTypeA, KPack> a_thread_vec;
                        vector_type<ComputeTypeB, KPack> b_thread_vec;

                        static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        constexpr index_t a_scale_offset =
                            a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));

                        constexpr index_t b_scale_offset =
                            b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));

                        vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
                        vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;

                        // Pack a/b scale_thread_bufs into scale_thread_vecs
                        static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
                            a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
                                a_scale_thread_bufs[I1][Number<a_scale_offset + s>{}];
                            b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
                                b_scale_thread_bufs[I1][Number<b_scale_offset + s>{}];
                        });

                        using mfma_input_type_a =
                            typename vector_type<ComputeTypeA,
                                                 xdlops_gemm.K1PerXdlops / APackedSize>::type;
                        using mfma_input_type_b =
                            typename vector_type<ComputeTypeB,
                                                 xdlops_gemm.K1PerXdlops / BPackedSize>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        // MFMA accumulation
                        xdlops_gemm.template Run<>(
                            a_thread_vec.template AsType<mfma_input_type_a>(),
                            a_scale_thread_vec.template AsType<AScaleDataType>(),
                            b_thread_vec.template AsType<mfma_input_type_b>(),
                            b_scale_thread_vec.template AsType<BScaleDataType>(),
                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        }
        else if constexpr(TailNum == TailNumber::Odd)
        {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        vector_type<ComputeTypeA, KPack> a_thread_vec;
                        vector_type<ComputeTypeB, KPack> b_thread_vec;

                        static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        constexpr index_t a_scale_offset =
                            a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));

                        constexpr index_t b_scale_offset =
                            b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));

                        vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
                        vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;

                        // Pack a/b scale_thread_bufs into scale_thread_vecs
                        static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
                            a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
                                a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
                            b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
                                b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
                        });

                        using mfma_input_type_a =
                            typename vector_type<ComputeTypeA,
                                                 xdlops_gemm.K1PerXdlops / APackedSize>::type;
                        using mfma_input_type_b =
                            typename vector_type<ComputeTypeB,
                                                 xdlops_gemm.K1PerXdlops / BPackedSize>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        // MFMA accumulation
                        xdlops_gemm.template Run<>(
                            a_thread_vec.template AsType<mfma_input_type_a>(),
                            a_scale_thread_vec.template AsType<AScaleDataType>(),
                            b_thread_vec.template AsType<mfma_input_type_b>(),
                            b_scale_thread_vec.template AsType<BScaleDataType>(),
                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        }
    }

    // TODO: make this field protected when a_scale_thread_copy_ is moved here
    static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));

    // Used to copy data from a_scale_grid to a_scale_thread
    static constexpr auto a_scale_thread_desc_copy =
        make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));

    // TODO: make this field protected when b_scale_thread_copy_ is moved here
    static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
        make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));

    // Used to copy data from b_scale_grid to b_scale_thread_buf
    static constexpr auto b_scale_thread_desc_copy =
        make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
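    // Editor's note (illustrative): these packed descriptors index row-major, e.g.
    //   b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s))
    //       == (n0 * KRepeat + k0) * ScalesPerXdlopsRunPerThread + s
    // which is the slot the per-xdlops scale prefetch in the hot loop writes into.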

    protected:
    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    // using Base::b_thread_desc_;
    using Base::c_thread_desc_;

    static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
};

} // namespace ck
File diff suppressed because it is too large
@@ -8,6 +8,7 @@
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp"
@@ -171,26 +172,54 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector()
        static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3");
        if constexpr(std::is_same<ADataType, BDataType>::value)
        {
            return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlkGemmPipeSche,
                                                               BlockSize,
                                                               ADataType,
                                                               BDataType,
                                                               ComputeDataType,
                                                               AccDataType,
                                                               ATileDesc,
                                                               BTileDesc,
                                                               AMmaTileDesc,
                                                               BMmaTileDesc,
                                                               ABlockTransferSrcScalarPerVector,
                                                               BBlockTransferSrcScalarPerVector,
                                                               MPerBlock,
                                                               NPerBlock,
                                                               KPerBlock,
                                                               MPerXDL,
                                                               NPerXDL,
                                                               MRepeat,
                                                               NRepeat,
                                                               KPack>{};
            if constexpr(GUFusion)
            {
                return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<
                    BlkGemmPipeSche,
                    BlockSize,
                    ADataType,
                    BDataType,
                    ComputeDataType,
                    AccDataType,
                    ATileDesc,
                    BTileDesc,
                    AMmaTileDesc,
                    BMmaTileDesc,
                    ABlockTransferSrcScalarPerVector,
                    BBlockTransferSrcScalarPerVector,
                    MPerBlock,
                    NPerBlock,
                    KPerBlock,
                    MPerXDL,
                    NPerXDL,
                    MRepeat,
                    NRepeat,
                    KPack>{};
            }
            else
            {
                return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlkGemmPipeSche,
                                                                   BlockSize,
                                                                   ADataType,
                                                                   BDataType,
                                                                   ComputeDataType,
                                                                   AccDataType,
                                                                   ATileDesc,
                                                                   BTileDesc,
                                                                   AMmaTileDesc,
                                                                   BMmaTileDesc,
                                                                   ABlockTransferSrcScalarPerVector,
                                                                   BBlockTransferSrcScalarPerVector,
                                                                   MPerBlock,
                                                                   NPerBlock,
                                                                   KPerBlock,
                                                                   MPerXDL,
                                                                   NPerXDL,
                                                                   MRepeat,
                                                                   NRepeat,
                                                                   KPack>{};
            }
        }
        else
        {
@@ -0,0 +1,123 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp"

namespace ck {

template <BlockGemmPipelineVersion BlkGemmPipelineVer,
          BlockGemmPipelineScheduler BlkGemmPipeSche,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MScaleBlock,
          index_t NScaleBlock,
          index_t KScaleBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
constexpr auto BlockGemmBlockScaleBPreshufflePipeline_Selector()
{
    if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
    {
        return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1<
            BlkGemmPipeSche,
            BlockSize,
            ADataType,
            BDataType,
            ComputeDataType,
            AccDataType,
            ATileDesc,
            BTileDesc,
            AMmaTileDesc,
            BMmaTileDesc,
            ABlockTransferSrcScalarPerVector,
            BBlockTransferSrcScalarPerVector,
            MPerBlock,
            NPerBlock,
            KPerBlock,
            MScaleBlock,
            NScaleBlock,
            KScaleBlock,
            MPerXDL,
            NPerXDL,
            MRepeat,
            NRepeat,
            KPack>{};
    }
#if 0
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
    {
        return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v2<
            BlkGemmPipeSche,
            BlockSize,
            ADataType,
            BDataType,
            ComputeDataType,
            AccDataType,
            ATileDesc,
            BTileDesc,
            AMmaTileDesc,
            BMmaTileDesc,
            ABlockTransferSrcScalarPerVector,
            BBlockTransferSrcScalarPerVector,
            MPerBlock,
            NPerBlock,
            KPerBlock,
            MPerXDL,
            NPerXDL,
            MRepeat,
            NRepeat,
            KPack>{};
    }
#endif
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
    {
        static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3");
        return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<
            BlkGemmPipeSche,
            BlockSize,
            ADataType,
            BDataType,
            ComputeDataType,
            AccDataType,
            ATileDesc,
            BTileDesc,
            AMmaTileDesc,
            BMmaTileDesc,
            ABlockTransferSrcScalarPerVector,
            BBlockTransferSrcScalarPerVector,
            MPerBlock,
            NPerBlock,
            KPerBlock,
            MScaleBlock,
            NScaleBlock,
            KScaleBlock,
            MPerXDL,
            NPerXDL,
            MRepeat,
            NRepeat,
            KPack>{};
    }
    else
    {
        std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
    }
}

} // namespace ck
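// Editor's note: a hypothetical instantiation sketch; the tile/type values below
// are illustrative assumptions, not taken from this patch.
//
//   auto blockwise_gemm = BlockGemmBlockScaleBPreshufflePipeline_Selector<
//       BlockGemmPipelineVersion::v1,
//       BlockGemmPipelineScheduler::Intrawave,
//       /*BlockSize*/ 256,
//       f8_t, f8_t, /*ComputeDataType*/ f8_t, /*AccDataType*/ float,
//       ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
//       /*ABlockTransferSrcScalarPerVector*/ 16, /*B...*/ 16,
//       /*MPerBlock, NPerBlock, KPerBlock*/ 128, 128, 128,
//       /*MScaleBlock, NScaleBlock, KScaleBlock*/ 1, 128, 128,
//       /*MPerXDL, NPerXDL*/ 32, 32, /*MRepeat, NRepeat*/ 4, 2, /*KPack*/ 16>();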

@@ -0,0 +1,864 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"

namespace ck {

// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1

template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MScaleBlock,
          index_t NScaleBlock,
          index_t KScaleBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1
{
};

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MScaleBlock,
          index_t NScaleBlock,
          index_t KScaleBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // bool TransposeC // TransposeC is disabled right now
          >
struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1<BlockGemmPipelineScheduler::Intrawave,
                                                              BlockSize,
                                                              ADataType,
                                                              BDataType,
                                                              ComputeDataType,
                                                              AccDataType,
                                                              ATileDesc,
                                                              BTileDesc,
                                                              AMmaTileDesc,
                                                              BMmaTileDesc,
                                                              ABlockTransferSrcScalarPerVector,
                                                              BBlockTransferSrcScalarPerVector,
                                                              MPerBlock,
                                                              NPerBlock,
                                                              KPerBlock,
                                                              MScaleBlock,
                                                              NScaleBlock,
                                                              KScaleBlock,
                                                              MPerXDL,
                                                              NPerXDL,
                                                              MRepeat,
                                                              NRepeat,
                                                              KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack,
                                        true>

{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack,
                                                   true>;
    using Base::A_K1;
    using Base::B_K1;
    using Base::I0;
    using Base::I1;
    using Base::KGroup;
    using Base::KRepeat;
    using Base::xdlops_gemm;
    using typename Base::HotLoopInstList;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::MWaves;
    using Base::NWaves;

    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 2;

    template <typename TileDesc_M0_M1_M2_K>
    __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
    {
        constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
        constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
        constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
        constexpr index_t K2 = KPack / KGroup;
        constexpr index_t K1 = 64 / NPerXDL;
        constexpr index_t K0 = KRepeat * KGroup;

        return transform_tensor_descriptor(
            TileDesc_M0_M1_M2_K{},
            make_tuple(
                make_pass_through_transform(Number<M0>{}),
                make_pass_through_transform(Number<M1>{}),
                make_pass_through_transform(Number<M2>{}),
                make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
    }
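    // Editor's note (illustrative): the unmerge above splits A's K dimension into
    // (K0, K1, K2) = (KRepeat * KGroup, 64 / NPerXDL, KPack / KGroup), i.e. xdlops
    // K iterations, K-lanes inside the 64-wide wavefront, and the contiguous
    // per-thread KPack chunk; e.g. NPerXDL = 32, KRepeat = 2, KGroup = 1,
    // KPack = 16 gives (2, 2, 16).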

    static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
        MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
    }
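    // Editor's note (illustrative): num_loop counts KPerBlock steps, conventionally
    // num_loop = K / KPerBlock. For K = 1024, KPerBlock = 128: num_loop = 8 >
    // PrefetchStages, so the hot loop runs and BlockLoopTailNum returns
    // TailNumber::Even; K = 1152 would give num_loop = 9 and an Odd tail.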

    __device__ static constexpr auto HotLoopScheduler()
    {
        constexpr auto num_ds_read_inst_a     = HotLoopInstList::A_LDS_Read_Inst_Num;
        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;

        // B global
        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
        });

        // A global
        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
        });

        // A local
        static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
        });
    }
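    // Editor's note: __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) pins
    // `size` instructions of the class selected by `mask` at this point in the
    // instruction schedule; the masks above follow the LLVM AMDGPU encoding
    // (0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write). The
    // loops therefore interleave one MFMA per B global load, an
    // MFMA/DS-write/MFMA/VMEM pattern per A global load, and one MFMA per two LDS
    // reads of A.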

    template <bool HasMainLoop,
              int NumKBlockPerScale,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CScaleThreadDesc,
              typename CThreadBuffer,
              typename AScaleGridBuffer,
              typename AScaleGridDesc,
              typename AScaleThreadDesc,
              typename AScaleThreadTransfer,
              typename AScaleThreadTransferStep,
              typename BScaleGridBuffer,
              typename BScaleGridDesc,
              typename BScaleThreadDesc,
              typename BScaleThreadTransfer,
              typename BScaleThreadTransferStep>
    __device__ void Run(
        // ABlockCopy
        const AGridDesc& a_grid_desc,
        const ABlockDesc& a_block_desc,
        ABlockTransfer& a_blockwise_copy,
        const AGridBuffer& a_grid_buf,
        ABlockBuffer& a_block_buf,
        const ABlockTransferStep& a_block_copy_step,
        // BBlockCopy
        const BGridDesc& b_grid_desc,
        const BBlockDesc& b_block_desc,
        BBlockTransfer& b_blockwise_copy,
        const BGridBuffer& b_grid_buf,
        BBlockBuffer& b_block_buf,
        const BBlockTransferStep& b_block_copy_step,
        // CThread
        const CScaleThreadDesc& c_scale_thread_desc,
        CThreadBuffer& c_thread_buf,
        // AScaleThreadCopy
        const AScaleGridDesc& a_scale_grid_desc,
        const AScaleThreadDesc& a_scale_thread_desc,
        AScaleThreadTransfer& a_scale_thread_copy,
        const AScaleGridBuffer& a_scale_grid_buf,
        const AScaleThreadTransferStep& a_scale_thread_copy_step,
        // BScaleThreadCopy
        const BScaleGridDesc& b_scale_grid_desc,
        const BScaleThreadDesc& b_scale_thread_desc,
        BScaleThreadTransfer& b_scale_thread_copy,
        const BScaleGridBuffer& b_scale_grid_buf,
        const BScaleThreadTransferStep& b_scale_thread_copy_step,
        // num_loop
        index_t num_loop) const
    {
        ignore = b_block_desc;
        ignore = b_block_buf;
        // __builtin_amdgcn_sched_barrier(0);
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);

        auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
            a_scale_thread_desc.GetElementSpaceSize());
        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
            b_scale_thread_desc.GetElementSpaceSize());
        auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
            c_scale_thread_desc.GetElementSpaceSize());

        // Global prefetch A1 B1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.Run(b_grid_desc,
                             b_grid_buf,
                             b_block_desc_n0_n1_k0_k1,
                             b_block_origin_idx,
                             b_thread_bufs(I0));

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        static_for<0, MRepeat, 1>{}([&](auto m0) {
            a_scale_thread_copy.Run(a_scale_grid_desc,
                                    a_scale_grid_buf,
                                    a_scale_thread_desc,
                                    make_tuple(m0, I0),
                                    a_scale_thread_buf);
            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                   a_scale_thread_copy_step.At(Number<0>{}));
        });

        if constexpr(NumKBlockPerScale == 1)
        {
            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                   a_scale_thread_copy_step.At(Number<2>{}));
        }
        else
        {
            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                   a_scale_thread_copy_step.At(Number<1>{}));
        }

        b_scale_thread_copy.Run(b_scale_grid_desc,
                                b_scale_grid_buf,
                                b_scale_thread_desc,
                                make_tuple(I0, I0),
                                b_scale_thread_buf);

        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);

        __builtin_amdgcn_sched_barrier(0);

        constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
        constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
        constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});

        static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
            static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
                static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
                    constexpr index_t c_offset =
                        CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
                    constexpr index_t a_offset =
                        AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
                    constexpr index_t b_offset =
                        BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));

                    c_scale_thread_buf(Number<c_offset>{}) =
                        a_scale_thread_buf[Number<a_offset>{}] *
                        b_scale_thread_buf[Number<b_offset>{}];
                });
            });
        });
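        // Editor's note (illustrative): the loop above precomputes the
        // dequantization factor per (k-scale, m-scale, n-scale) block, so the hot
        // loop applies block-scale GEMM as
        //   C[m, n] += (sum over the scale block of A_q[m, k] * B_q[n, k])
        //              * a_scale[m_blk, k_blk] * b_scale[n_blk, k_blk]
        // with the two scale multiplies already folded into one c_scale factor.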

        // Local prefill A1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);

        // Global prefetch A2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

        static_for<0, MRepeat, 1>{}([&](auto m0) {
            a_scale_thread_copy.Run(a_scale_grid_desc,
                                    a_scale_grid_buf,
                                    a_scale_thread_desc,
                                    make_tuple(m0, I0),
                                    a_scale_thread_buf);
            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                   a_scale_thread_copy_step.At(Number<0>{}));
        });

        if constexpr(NumKBlockPerScale == 1)
        {
            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                   a_scale_thread_copy_step.At(Number<2>{}));
        }
        else
        {
            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                   a_scale_thread_copy_step.At(Number<1>{}));
        }

        b_scale_thread_copy.Run(b_scale_grid_desc,
                                b_scale_grid_buf,
                                b_scale_thread_desc,
                                make_tuple(I0, I0),
                                b_scale_thread_buf);

        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);

        StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                                  AccDataType,
                                  1,
                                  xdlops_gemm.GetRegSizePerXdlops(),
                                  true>
            c_thread_buf_per_scale;

        // Local prefetch A1
        block_sync_lds();
        static_for<0, MRepeat, 1>{}([&](auto m0) {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, KGroup, 1>{}([&](auto kg0) {
                    a_thread_copy_.Run(
                        a_block_desc_m0_m1_m2_k0_k1_k2,
                        make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
                        a_block_buf,
                        a_thread_desc_,
                        make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
                        a_thread_buf);
                });
            });
        });

        // Initialize C
        c_thread_buf.Clear();

        // __builtin_amdgcn_sched_barrier(0);

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
                    b_blockwise_copy.Run(b_grid_desc,
                                         b_grid_buf,
                                         b_block_desc_n0_n1_k0_k1,
                                         b_block_origin_idx,
                                         b_thread_bufs(local_read_buf));
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                    block_sync_lds();
                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);

                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
                                static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                                    c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                                        .template AsType<AccDataType>()(Number<t>{}) = 0;
                                });
                                vector_type<AccDataType, 2> c_scale_thread_vec;
                                constexpr index_t cscale_offset =
                                    CScaleThreadDesc{}.CalculateOffset(
                                        make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));

                                c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
                                    c_scale_thread_buf[Number<cscale_offset>{}];
                                c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
                                    c_scale_thread_buf[Number<cscale_offset>{}];

                                static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
                                    vector_type<ComputeDataType, KPack> a_thread_vec;
                                    vector_type<ComputeDataType, KPack> b_thread_vec;

                                    static_for<0, KPack, 1>{}([&](auto ik) {
                                        a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                                make_tuple(m0,
                                                           I0,
                                                           I0,
                                                           kscale0 * KRepeat / num_scale_k_block +
                                                               k0,
                                                           I0,
                                                           ik))>{}];
                                        b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                            b_thread_bufs[mfma_reg_buf][Number<
                                                b_thread_desc_.CalculateOffset(make_tuple(
                                                    n0,
                                                    I0,
                                                    kscale0 * KRepeat / num_scale_k_block + k0,
                                                    ik))>{}];
                                    });

                                    using mfma_input_type =
                                        typename vector_type<ComputeDataType,
                                                             xdlops_gemm.K1PerXdlops>::type;

                                    xdlops_gemm.template Run<>(
                                        a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
                                });
                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}(
                                    [&](auto t) {
                                        using pk_fma_type =
                                            typename vector_type<AccDataType, 2>::type;

                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
                                            .template AsType<pk_fma_type>()(t) =
                                            __builtin_elementwise_fma(
                                                c_thread_buf_per_scale
                                                    .GetVectorTypeReference(Number<0>{})
                                                    .template AsType<pk_fma_type>()[t],
                                                c_scale_thread_vec
                                                    .template AsType<pk_fma_type>()[Number<0>{}],
                                                c_thread_buf
                                                    .GetVectorTypeReference(Number<c_offset>{})
                                                    .template AsType<pk_fma_type>()[t]);
                                    });
                            });
                        });
                    });
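                    // Editor's note (illustrative): the packed FMA above folds each
                    // per-scale-block partial sum into the final accumulator two
                    // floats at a time; per lane t it computes
                    //   c_final[t] = c_partial[t] * c_scale + c_final[t]
                    // where c_scale is the a_scale * b_scale product materialized
                    // before the loop.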

                    block_sync_lds();

                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, KRepeat, 1>{}([&](auto k0) {
                            static_for<0, KGroup, 1>{}([&](auto kg0) {
                                a_thread_copy_.Run(
                                    a_block_desc_m0_m1_m2_k0_k1_k2,
                                    make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
                                    a_block_buf,
                                    a_thread_desc_,
                                    make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
                                    a_thread_buf);
                            });
                        });
                    });

                    HotLoopScheduler();
                    __builtin_amdgcn_sched_barrier(0);

                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
                            static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
                                constexpr index_t c_offset =
                                    CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
                                constexpr index_t a_offset =
                                    AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
                                constexpr index_t b_offset =
                                    BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));

                                c_scale_thread_buf(Number<c_offset>{}) =
                                    a_scale_thread_buf[Number<a_offset>{}] *
                                    b_scale_thread_buf[Number<b_offset>{}];
                            });
                        });
                    });

                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_scale_thread_copy.Run(a_scale_grid_desc,
                                                a_scale_grid_buf,
                                                a_scale_thread_desc,
                                                make_tuple(m0, I0),
                                                a_scale_thread_buf);
                        a_scale_thread_copy.MoveSrcSliceWindow(
                            a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
                    });

                    if constexpr(NumKBlockPerScale == 1)
                    {
                        a_scale_thread_copy.MoveSrcSliceWindow(
                            a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
                    }
                    else
                    {
                        a_scale_thread_copy.MoveSrcSliceWindow(
                            a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
                    }

                    b_scale_thread_copy.Run(b_scale_grid_desc,
                                            b_scale_grid_buf,
                                            b_scale_thread_desc,
                                            make_tuple(I0, I0),
                                            b_scale_thread_buf);

                    b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                           b_scale_thread_copy_step);
                };

                LoopFunc(I0, I1);
                LoopFunc(I1, I0);

                i += 2;
            } while(i < (num_loop - 2));
        }

        // tail
        if constexpr(TailNum == TailNumber::Even)
        {
            b_blockwise_copy.Run(b_grid_desc,
                                 b_grid_buf,
                                 b_block_desc_n0_n1_k0_k1,
                                 b_block_origin_idx,
                                 b_thread_bufs(I1));
            block_sync_lds();
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
                        static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                            c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                                .template AsType<AccDataType>()(Number<t>{}) = 0;
                        });
                        vector_type<AccDataType, 2> c_scale_thread_vec;
                        constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
                            make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));

                        c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
                            c_scale_thread_buf[Number<cscale_offset>{}];
                        c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
                            c_scale_thread_buf[Number<cscale_offset>{}];

                        static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0,
                                                   I0,
                                                   I0,
                                                   kscale0 * KRepeat / num_scale_k_block + k0,
                                                   I0,
                                                   ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0,
                                                   I0,
                                                   kscale0 * KRepeat / num_scale_k_block + k0,
                                                   ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            xdlops_gemm.template Run<>(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
                        });
                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
                            using pk_fma_type = typename vector_type<AccDataType, 2>::type;

                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
                                .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
                                c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                                    .template AsType<pk_fma_type>()[t],
                                c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
                                    .template AsType<pk_fma_type>()[t]);
                        });
                    });
                });
            });

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
                    static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
                        constexpr index_t c_offset =
                            CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
                        constexpr index_t a_offset =
                            AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
                        constexpr index_t b_offset =
                            BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));

                        c_scale_thread_buf(Number<c_offset>{}) =
                            a_scale_thread_buf[Number<a_offset>{}] *
                            b_scale_thread_buf[Number<b_offset>{}];
                    });
                });
            });

            block_sync_lds();

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, KGroup, 1>{}([&](auto kg0) {
                        a_thread_copy_.Run(
                            a_block_desc_m0_m1_m2_k0_k1_k2,
                            make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
                            a_block_buf,
                            a_thread_desc_,
                            make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
                            a_thread_buf);
                    });
                });
            });

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
                        static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                            c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                                .template AsType<AccDataType>()(Number<t>{}) = 0;
                        });
                        vector_type<AccDataType, 2> c_scale_thread_vec;
                        constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
                            make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));

                        c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
                            c_scale_thread_buf[Number<cscale_offset>{}];
                        c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
                            c_scale_thread_buf[Number<cscale_offset>{}];

                        static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0,
                                                   I0,
                                                   I0,
                                                   kscale0 * KRepeat / num_scale_k_block + k0,
                                                   I0,
                                                   ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0,
                                                   I0,
                                                   kscale0 * KRepeat / num_scale_k_block + k0,
                                                   ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            xdlops_gemm.template Run<>(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
                        });
                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
                            using pk_fma_type = typename vector_type<AccDataType, 2>::type;

                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
                                .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
                                c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                                    .template AsType<pk_fma_type>()[t],
                                c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
                                    .template AsType<pk_fma_type>()[t]);
                        });
                    });
                });
            });
        }
        else if constexpr(TailNum == TailNumber::Odd)
        {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
                        static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                            c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                                .template AsType<AccDataType>()(Number<t>{}) = 0;
                        });
                        vector_type<AccDataType, 2> c_scale_thread_vec;
                        constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
                            make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));

                        c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
                            c_scale_thread_buf[Number<cscale_offset>{}];
                        c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
                            c_scale_thread_buf[Number<cscale_offset>{}];

                        static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0,
                                                   I0,
                                                   I0,
                                                   kscale0 * KRepeat / num_scale_k_block + k0,
                                                   I0,
                                                   ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0,
                                                   I0,
                                                   kscale0 * KRepeat / num_scale_k_block + k0,
                                                   ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            xdlops_gemm.template Run<>(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
                        });

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
                            using pk_fma_type = typename vector_type<AccDataType, 2>::type;

                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
                                .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
                                c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                                    .template AsType<pk_fma_type>()[t],
                                c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
                                    .template AsType<pk_fma_type>()[t]);
                        });
                    });
                });
            });
        }
    }

    protected:
    // MRepeat MWave MLane KRepeat KLane KPack
    // KRepeat -> MRepeat -> MWave -> KLane -> MLane -> KPack
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
                                                         ComputeDataType,
                                                         decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
                                                         Sequence<0, 1, 2, 3, 4, 5>,
                                                         5,
                                                         A_K1,
                                                         A_K1>;

    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};

    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));

    static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;

    using Base::c_thread_desc_;
};

} // namespace ck
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,186 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp"

namespace ck {

template <BlockGemmPipelineVersion BlkGemmPipelineVer,
          BlockGemmPipelineScheduler BlkGemmPipeSche,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MScaleBlock,
          index_t NScaleBlock,
          index_t KScaleBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          bool GUFusion = false>
constexpr auto BlockGemmBlockMoeScaleBPreshufflePipeline_Selector()
{
    if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
    {
        if constexpr(GUFusion)
        {
            return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1<
                BlkGemmPipeSche,
                BlockSize,
                ADataType,
                BDataType,
                ComputeDataType,
                AccDataType,
                ATileDesc,
                BTileDesc,
                AMmaTileDesc,
                BMmaTileDesc,
                ABlockTransferSrcScalarPerVector,
                BBlockTransferSrcScalarPerVector,
                MPerBlock,
                NPerBlock,
                KPerBlock,
                MScaleBlock,
                NScaleBlock,
                KScaleBlock,
                MPerXDL,
                NPerXDL,
                MRepeat,
                NRepeat,
                KPack>{};
        }
        else
        {
            return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
                BlkGemmPipeSche,
                BlockSize,
                ADataType,
                BDataType,
                ComputeDataType,
                AccDataType,
                ATileDesc,
                BTileDesc,
                AMmaTileDesc,
                BMmaTileDesc,
                ABlockTransferSrcScalarPerVector,
                BBlockTransferSrcScalarPerVector,
                MPerBlock,
                NPerBlock,
                KPerBlock,
                MScaleBlock,
                NScaleBlock,
                KScaleBlock,
                MPerXDL,
                NPerXDL,
                MRepeat,
                NRepeat,
                KPack>{};
        }
    }
#if 0
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
    {
        return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v2<
            BlkGemmPipeSche,
            BlockSize,
            ADataType,
            BDataType,
            ComputeDataType,
            AccDataType,
            ATileDesc,
            BTileDesc,
            AMmaTileDesc,
            BMmaTileDesc,
            ABlockTransferSrcScalarPerVector,
            BBlockTransferSrcScalarPerVector,
            MPerBlock,
            NPerBlock,
            KPerBlock,
            MPerXDL,
            NPerXDL,
            MRepeat,
            NRepeat,
            KPack>{};
    }
#endif
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
    {
        static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3");
        if constexpr(GUFusion)
        {
            return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
                BlkGemmPipeSche,
                BlockSize,
                ADataType,
                BDataType,
                ComputeDataType,
                AccDataType,
                ATileDesc,
                BTileDesc,
                AMmaTileDesc,
                BMmaTileDesc,
                ABlockTransferSrcScalarPerVector,
                BBlockTransferSrcScalarPerVector,
                MPerBlock,
                NPerBlock,
                KPerBlock,
                MScaleBlock,
                NScaleBlock,
                KScaleBlock,
                MPerXDL,
                NPerXDL,
                MRepeat,
                NRepeat,
                KPack>{};
        }
        else
        {
            return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
                BlkGemmPipeSche,
                BlockSize,
                ADataType,
                BDataType,
                ComputeDataType,
                AccDataType,
                ATileDesc,
                BTileDesc,
                AMmaTileDesc,
                BMmaTileDesc,
                ABlockTransferSrcScalarPerVector,
                BBlockTransferSrcScalarPerVector,
                MPerBlock,
                NPerBlock,
                KPerBlock,
                MScaleBlock,
                NScaleBlock,
                KScaleBlock,
                MPerXDL,
                NPerXDL,
                MRepeat,
                NRepeat,
                KPack>{};
        }
    }
    else
    {
        std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
    }
}

} // namespace ck
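// Editor's note: reading the commit log, GUFusion appears to select the fused
// gate/up MoE GEMM1 pipelines (the "_gufusion_" variants above), while the
// default GUFusion = false path keeps the unfused per-GEMM pipelines. This
// expansion of "GU" is an inference from the commit messages, not stated in the
// code itself.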
|
||||
@@ -0,0 +1,854 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MScaleBlock,
|
||||
index_t NScaleBlock,
|
||||
index_t KScaleBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks>
|
||||
struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MScaleBlock,
|
||||
index_t NScaleBlock,
|
||||
index_t KScaleBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
>
|
||||
struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MScaleBlock,
|
||||
NScaleBlock,
|
||||
KScaleBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack> : BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
true>
|
||||
|
||||
{
|
||||
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack,
                                                   true>;
    using Base::A_K1;
    using Base::B_K1;
    using Base::I0;
    using Base::I1;
    using Base::KGroup;
    using Base::KRepeat;
    using Base::xdlops_gemm;
    using typename Base::HotLoopInstList;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::MWaves;
    using Base::NWaves;

    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 2;
    template <typename TileDesc_M0_M1_M2_K>
    __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
    {
        constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
        constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
        constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
        constexpr index_t K2 = KPack / KGroup;
        constexpr index_t K1 = 64 / NPerXDL;
        constexpr index_t K0 = KRepeat * KGroup;

        return transform_tensor_descriptor(
            TileDesc_M0_M1_M2_K{},
            make_tuple(
                make_pass_through_transform(Number<M0>{}),
                make_pass_through_transform(Number<M1>{}),
                make_pass_through_transform(Number<M2>{}),
                make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
    }

    static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
        MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
    }
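
    // Example (illustrative, hypothetical loop count): with PrefetchStages = 2,
    // a problem split into num_loop = 7 K-block iterations has a hot loop
    // (7 > 2) and takes the Odd tail, draining the last prefetched B tile from
    // register buffer 0; num_loop = 8 takes the Even tail and drains buffers 0
    // and 1.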
    __device__ static constexpr auto HotLoopScheduler()
    {
        constexpr auto num_ds_read_inst_a     = HotLoopInstList::A_LDS_Read_Inst_Num;
        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;

        // B global
        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
        });

        // A global
        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
        });

        // A local
        static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
        });
    }
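
    // Note on the intrinsics above: the first argument of
    // __builtin_amdgcn_sched_group_barrier is an instruction-class mask defined
    // by the AMDGPU backend (0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read,
    // 0x200 = DS write), the second is how many instructions of that class to
    // place, and the third is a sync-group id. The loops above therefore ask
    // the scheduler to interleave one global or LDS access between MFMA
    // instructions so memory latency is hidden inside the hot loop.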
    template <bool HasMainLoop,
              int NumKBlockPerScale,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CScaleThreadDesc,
              typename CThreadBuffer,
              typename AScaleGridBuffer,
              typename AScaleGridDesc,
              typename AScaleThreadDesc,
              typename AScaleThreadTransfer,
              typename AScaleThreadTransferStep,
              typename BScaleGridBuffer,
              typename BScaleGridDesc,
              typename BScaleThreadDesc,
              typename BScaleThreadTransfer,
              typename BScaleThreadTransferStep>
    __device__ void Run(
        // ABlockCopy
        const AGridDesc& a_grid_desc,
        const ABlockDesc& a_block_desc,
        ABlockTransfer& a_blockwise_copy,
        const AGridBuffer& a_grid_buf,
        ABlockBuffer& a_block_buf,
        const ABlockTransferStep& a_block_copy_step,
        // BBlockCopy
        const BGridDesc& b_grid_desc,
        const BBlockDesc& b_block_desc,
        BBlockTransfer& b_blockwise_copy,
        const BGridBuffer& b_grid_buf,
        BBlockBuffer& b_block_buf,
        const BBlockTransferStep& b_block_copy_step,
        // CThread
        const CScaleThreadDesc& c_scale_thread_desc,
        CThreadBuffer& c_thread_buf,
        // AScaleThreadCopy
        const AScaleGridDesc& a_scale_grid_desc,
        const AScaleThreadDesc& a_scale_thread_desc,
        AScaleThreadTransfer& a_scale_thread_copy,
        const AScaleGridBuffer& a_scale_grid_buf,
        const AScaleThreadTransferStep& a_scale_thread_copy_step,
        // BScaleThreadCopy
        const BScaleGridDesc& b_scale_grid_desc,
        const BScaleThreadDesc& b_scale_thread_desc,
        BScaleThreadTransfer& b_scale_thread_copy,
        const BScaleGridBuffer& b_scale_grid_buf,
        const BScaleThreadTransferStep& b_scale_thread_copy_step,
        // num_loop
        index_t num_loop) const
    {
        ignore = b_block_desc;
        ignore = b_block_buf;
        // __builtin_amdgcn_sched_barrier(0);
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);

        auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
            a_scale_thread_desc.GetElementSpaceSize());
        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
            b_scale_thread_desc.GetElementSpaceSize());
        auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
            c_scale_thread_desc.GetElementSpaceSize());

        // Global prefetch A1 B1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.Run(b_grid_desc,
                             b_grid_buf,
                             b_block_desc_n0_n1_k0_k1,
                             b_block_origin_idx,
                             b_thread_bufs(I0));

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        a_scale_thread_copy.Run(a_scale_grid_desc,
                                a_scale_grid_buf,
                                a_scale_thread_desc,
                                make_tuple(I0, I0),
                                a_scale_thread_buf);

        if constexpr(NumKBlockPerScale == 1)
        {
            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                   a_scale_thread_copy_step.At(Number<1>{}));
        }
        else
        {
            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                   a_scale_thread_copy_step.At(Number<0>{}));
        }

        b_scale_thread_copy.Run(b_scale_grid_desc,
                                b_scale_grid_buf,
                                b_scale_thread_desc,
                                make_tuple(I0, I0),
                                b_scale_thread_buf);

        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);

        __builtin_amdgcn_sched_barrier(0);
        constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
        constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
        constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});

        static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
            static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
                static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
                    constexpr index_t c_offset =
                        CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
                    constexpr index_t a_offset =
                        AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
                    constexpr index_t b_offset =
                        BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));

                    c_scale_thread_buf(Number<c_offset>{}) =
                        a_scale_thread_buf[Number<a_offset>{}] *
                        b_scale_thread_buf[Number<b_offset>{}];
                });
            });
        });
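
        // The product above implements blockscale dequantization: since
        // (a_scale * A) @ (b_scale * B) == (a_scale * b_scale) * (A @ B) within
        // one (MScaleBlock, NScaleBlock, KScaleBlock) tile, the two per-block
        // scales are folded into a single multiplier here and applied to the
        // raw MFMA partial sums later with a packed FMA.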
        // Local prefill A1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);

        // Global prefetch A2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

        a_scale_thread_copy.Run(a_scale_grid_desc,
                                a_scale_grid_buf,
                                a_scale_thread_desc,
                                make_tuple(I0, I0),
                                a_scale_thread_buf);

        if constexpr(NumKBlockPerScale == 1)
        {
            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                   a_scale_thread_copy_step.At(Number<1>{}));
        }
        else
        {
            a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                   a_scale_thread_copy_step.At(Number<0>{}));
        }

        b_scale_thread_copy.Run(b_scale_grid_desc,
                                b_scale_grid_buf,
                                b_scale_thread_desc,
                                make_tuple(I0, I0),
                                b_scale_thread_buf);

        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);

        StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                                  AccDataType,
                                  1,
                                  xdlops_gemm.GetRegSizePerXdlops(),
                                  true>
            c_thread_buf_per_scale;
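
        // c_thread_buf_per_scale is a small per-scale-block accumulator: it is
        // zeroed for each (m0, n0, kscale0) iteration, collects the unscaled
        // MFMA results, and is then folded into the persistent c_thread_buf
        // with the c_scale multiplier, two float lanes at a time via
        // __builtin_elementwise_fma.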
        // Local prefetch A1
        block_sync_lds();
        static_for<0, MRepeat, 1>{}([&](auto m0) {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, KGroup, 1>{}([&](auto kg0) {
                    a_thread_copy_.Run(
                        a_block_desc_m0_m1_m2_k0_k1_k2,
                        make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
                        a_block_buf,
                        a_thread_desc_,
                        make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
                        a_thread_buf);
                });
            });
        });

        // Initialize C
        c_thread_buf.Clear();

        // __builtin_amdgcn_sched_barrier(0);
        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
                    b_blockwise_copy.Run(b_grid_desc,
                                         b_grid_buf,
                                         b_block_desc_n0_n1_k0_k1,
                                         b_block_origin_idx,
                                         b_thread_bufs(local_read_buf));
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                    block_sync_lds();
                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);

                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
                                static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                                    c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                                        .template AsType<AccDataType>()(Number<t>{}) = 0;
                                });
                                vector_type<AccDataType, 2> c_scale_thread_vec;
                                constexpr index_t cscale_offset =
                                    CScaleThreadDesc{}.CalculateOffset(
                                        make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));

                                c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
                                    c_scale_thread_buf[Number<cscale_offset>{}];
                                c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
                                    c_scale_thread_buf[Number<cscale_offset>{}];

                                static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
                                    vector_type<ComputeDataType, KPack> a_thread_vec;
                                    vector_type<ComputeDataType, KPack> b_thread_vec;

                                    static_for<0, KPack, 1>{}([&](auto ik) {
                                        a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                                make_tuple(m0,
                                                           I0,
                                                           I0,
                                                           kscale0 * KRepeat / num_scale_k_block +
                                                               k0,
                                                           I0,
                                                           ik))>{}];
                                        b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                            b_thread_bufs[mfma_reg_buf][Number<
                                                b_thread_desc_.CalculateOffset(make_tuple(
                                                    n0,
                                                    I0,
                                                    kscale0 * KRepeat / num_scale_k_block + k0,
                                                    ik))>{}];
                                    });

                                    using mfma_input_type =
                                        typename vector_type<ComputeDataType,
                                                             xdlops_gemm.K1PerXdlops>::type;

                                    xdlops_gemm.template Run<>(
                                        a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
                                });

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}(
                                    [&](auto t) {
                                        using pk_fma_type =
                                            typename vector_type<AccDataType, 2>::type;

                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
                                            .template AsType<pk_fma_type>()(t) =
                                            __builtin_elementwise_fma(
                                                c_thread_buf_per_scale
                                                    .GetVectorTypeReference(Number<0>{})
                                                    .template AsType<pk_fma_type>()[t],
                                                c_scale_thread_vec
                                                    .template AsType<pk_fma_type>()[Number<0>{}],
                                                c_thread_buf
                                                    .GetVectorTypeReference(Number<c_offset>{})
                                                    .template AsType<pk_fma_type>()[t]);
                                    });
                            });
                        });
                    });

                    block_sync_lds();

                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, KRepeat, 1>{}([&](auto k0) {
                            static_for<0, KGroup, 1>{}([&](auto kg0) {
                                a_thread_copy_.Run(
                                    a_block_desc_m0_m1_m2_k0_k1_k2,
                                    make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
                                    a_block_buf,
                                    a_thread_desc_,
                                    make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
                                    a_thread_buf);
                            });
                        });
                    });

                    HotLoopScheduler();
                    __builtin_amdgcn_sched_barrier(0);

                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
                            static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
                                constexpr index_t c_offset =
                                    CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
                                constexpr index_t a_offset =
                                    AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
                                constexpr index_t b_offset =
                                    BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));

                                c_scale_thread_buf(Number<c_offset>{}) =
                                    a_scale_thread_buf[Number<a_offset>{}] *
                                    b_scale_thread_buf[Number<b_offset>{}];
                            });
                        });
                    });

                    a_scale_thread_copy.Run(a_scale_grid_desc,
                                            a_scale_grid_buf,
                                            a_scale_thread_desc,
                                            make_tuple(I0, I0),
                                            a_scale_thread_buf);

                    if constexpr(NumKBlockPerScale == 1)
                    {
                        a_scale_thread_copy.MoveSrcSliceWindow(
                            a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
                    }
                    else
                    {
                        a_scale_thread_copy.MoveSrcSliceWindow(
                            a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
                    }

                    b_scale_thread_copy.Run(b_scale_grid_desc,
                                            b_scale_grid_buf,
                                            b_scale_thread_desc,
                                            make_tuple(I0, I0),
                                            b_scale_thread_buf);

                    b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                           b_scale_thread_copy_step);
                };

                LoopFunc(I0, I1);
                LoopFunc(I1, I0);

                i += 2;
            } while(i < (num_loop - 2));
        }
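
        // The two LoopFunc invocations ping-pong the double-buffered B register
        // tiles: the first issues MFMAs from buffer 0 while buffer 1 is being
        // filled, the second swaps roles. Each trip through the do-while thus
        // consumes two K-block iterations, and BlockLoopTailNum() selects the
        // matching Even/Odd epilogue below to drain the remaining prefetches.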
        // tail
        if constexpr(TailNum == TailNumber::Even)
        {
            b_blockwise_copy.Run(b_grid_desc,
                                 b_grid_buf,
                                 b_block_desc_n0_n1_k0_k1,
                                 b_block_origin_idx,
                                 b_thread_bufs(I1));
            block_sync_lds();
            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
                        static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                            c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                                .template AsType<AccDataType>()(Number<t>{}) = 0;
                        });
                        vector_type<AccDataType, 2> c_scale_thread_vec;
                        constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
                            make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));

                        c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
                            c_scale_thread_buf[Number<cscale_offset>{}];
                        c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
                            c_scale_thread_buf[Number<cscale_offset>{}];

                        static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0,
                                                   I0,
                                                   I0,
                                                   kscale0 * KRepeat / num_scale_k_block + k0,
                                                   I0,
                                                   ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0,
                                                   I0,
                                                   kscale0 * KRepeat / num_scale_k_block + k0,
                                                   ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            xdlops_gemm.template Run<>(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
                        });
                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
                            using pk_fma_type = typename vector_type<AccDataType, 2>::type;

                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
                                .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
                                c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                                    .template AsType<pk_fma_type>()[t],
                                c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
                                    .template AsType<pk_fma_type>()[t]);
                        });
                    });
                });
            });

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
                    static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
                        constexpr index_t c_offset =
                            CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
                        constexpr index_t a_offset =
                            AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
                        constexpr index_t b_offset =
                            BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));

                        c_scale_thread_buf(Number<c_offset>{}) =
                            a_scale_thread_buf[Number<a_offset>{}] *
                            b_scale_thread_buf[Number<b_offset>{}];
                    });
                });
            });

            block_sync_lds();
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, KGroup, 1>{}([&](auto kg0) {
                        a_thread_copy_.Run(
                            a_block_desc_m0_m1_m2_k0_k1_k2,
                            make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
                            a_block_buf,
                            a_thread_desc_,
                            make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
                            a_thread_buf);
                    });
                });
            });

            // __builtin_amdgcn_sched_barrier(0);

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
                        static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                            c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                                .template AsType<AccDataType>()(Number<t>{}) = 0;
                        });
                        vector_type<AccDataType, 2> c_scale_thread_vec;
                        constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
                            make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));

                        c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
                            c_scale_thread_buf[Number<cscale_offset>{}];
                        c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
                            c_scale_thread_buf[Number<cscale_offset>{}];

                        static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0,
                                                   I0,
                                                   I0,
                                                   kscale0 * KRepeat / num_scale_k_block + k0,
                                                   I0,
                                                   ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0,
                                                   I0,
                                                   kscale0 * KRepeat / num_scale_k_block + k0,
                                                   ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            xdlops_gemm.template Run<>(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
                        });
                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
                            using pk_fma_type = typename vector_type<AccDataType, 2>::type;

                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
                                .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
                                c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                                    .template AsType<pk_fma_type>()[t],
                                c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
                                    .template AsType<pk_fma_type>()[t]);
                        });
                    });
                });
            });
        }
        else if constexpr(TailNum == TailNumber::Odd)
        {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
                        static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
                            c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                                .template AsType<AccDataType>()(Number<t>{}) = 0;
                        });
                        vector_type<AccDataType, 2> c_scale_thread_vec;
                        constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
                            make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));

                        c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
                            c_scale_thread_buf[Number<cscale_offset>{}];
                        c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
                            c_scale_thread_buf[Number<cscale_offset>{}];

                        static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0,
                                                   I0,
                                                   I0,
                                                   kscale0 * KRepeat / num_scale_k_block + k0,
                                                   I0,
                                                   ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0,
                                                   I0,
                                                   kscale0 * KRepeat / num_scale_k_block + k0,
                                                   ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            xdlops_gemm.template Run<>(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
                        });
                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
                            using pk_fma_type = typename vector_type<AccDataType, 2>::type;

                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
                                .template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
                                c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
                                    .template AsType<pk_fma_type>()[t],
                                c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
                                    .template AsType<pk_fma_type>()[t]);
                        });
                    });
                });
            });
        }
    }
    protected:
    // MRepeat MWave MLane KRepeat KLane KPack
    // KRepeat -> MRepeat -> MWave -> KLane -> MLane -> KPack
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
                                                         ComputeDataType,
                                                         decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
                                                         Sequence<0, 1, 2, 3, 4, 5>,
                                                         5,
                                                         A_K1,
                                                         A_K1>;

    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};

    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));

    static constexpr BTileDesc b_block_desc_n0_n1_k0_k1{};

    using Base::c_thread_desc_;
};

} // namespace ck
@@ -0,0 +1,130 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp"

namespace ck {
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
          BlockGemmPipelineScheduler BlkGemmPipeSche,
          index_t ThreadBlockSize,
          index_t ScaleBlockSize,
          typename ADataType,
          typename AScaleDataType,
          typename BDataType,
          typename BScaleDataType,
          typename ComputeDataType, // TODO: remove this, as in this pipeline ADataType and
                                    // BDataType must be used for compute
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          bool GUFusion = false>
constexpr auto BlockGemmMXNBSPipeline_Selector()
{

    // Hardware MX GEMM pipeline
    if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
    {
        if constexpr(GUFusion)
        {
            // GU fusion is not available for the v1 pipeline
            return nullptr;
        }
        else
        {
            return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1<BlkGemmPipeSche,
                                                              ThreadBlockSize,
                                                              ScaleBlockSize,
                                                              ADataType,
                                                              AScaleDataType,
                                                              BDataType,
                                                              BScaleDataType,
                                                              ATileDesc,
                                                              BTileDesc,
                                                              AMmaTileDesc,
                                                              BMmaTileDesc,
                                                              ABlockTransferSrcScalarPerVector,
                                                              BBlockTransferSrcScalarPerVector,
                                                              MPerBlock,
                                                              NPerBlock,
                                                              KPerBlock,
                                                              MPerXDL,
                                                              NPerXDL,
                                                              MRepeat,
                                                              NRepeat,
                                                              KPack>{};
        }
    }
    else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
    {
        if constexpr(GUFusion)
        {
            return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_gufusion_v3<
                BlkGemmPipeSche,
                ThreadBlockSize,
                ScaleBlockSize,
                ADataType,
                AScaleDataType,
                BDataType,
                BScaleDataType,
                ATileDesc,
                BTileDesc,
                AMmaTileDesc,
                BMmaTileDesc,
                ABlockTransferSrcScalarPerVector,
                BBlockTransferSrcScalarPerVector,
                MPerBlock,
                NPerBlock,
                KPerBlock,
                MPerXDL,
                NPerXDL,
                MRepeat,
                NRepeat,
                KPack>{};
        }
        else
        {
            return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlkGemmPipeSche,
                                                              ThreadBlockSize,
                                                              ScaleBlockSize,
                                                              ADataType,
                                                              AScaleDataType,
                                                              BDataType,
                                                              BScaleDataType,
                                                              ATileDesc,
                                                              BTileDesc,
                                                              AMmaTileDesc,
                                                              BMmaTileDesc,
                                                              ABlockTransferSrcScalarPerVector,
                                                              BBlockTransferSrcScalarPerVector,
                                                              MPerBlock,
                                                              NPerBlock,
                                                              KPerBlock,
                                                              MPerXDL,
                                                              NPerXDL,
                                                              MRepeat,
                                                              NRepeat,
                                                              KPack>{};
        }
    }
    else
    {
        std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
    }
}

} // namespace ck
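
// Usage sketch (illustrative only; the concrete template arguments below are
// hypothetical and would normally be supplied by the surrounding gridwise
// kernel):
//
//   using BlockGemmPipe = remove_cvref_t<decltype(
//       BlockGemmMXNBSPipeline_Selector<BlockGemmPipelineVersion::v3,
//                                       BlockGemmPipelineScheduler::Intrawave,
//                                       /*ThreadBlockSize=*/256,
//                                       /*ScaleBlockSize=*/32,
//                                       ADataType, AScaleDataType,
//                                       BDataType, BScaleDataType,
//                                       ...>())>;
//
// Since the unsupported v1 + GUFusion combination returns nullptr, a caller
// can static_assert that the selected type is not std::nullptr_t before
// instantiating the kernel.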
@@ -0,0 +1,664 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"

namespace ck {

// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1

template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t ThreadBlockSize,
          index_t ScaleBlockSize,
          typename ADataType,
          typename AScaleDataType,
          typename BDataType,
          typename BScaleDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat, // MXdlPerWave
          index_t NRepeat, // NXdlPerWave
          index_t KPack>
struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1
{
};

template <index_t ThreadBlockSize,
          index_t ScaleBlockSize,
          typename ADataType,
          typename AScaleDataType,
          typename BDataType,
          typename BScaleDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat, // MXdlPerWave
          index_t NRepeat, // NXdlPerWave
          index_t KPack>
struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1<BlockGemmPipelineScheduler::Intrawave,
                                                  ThreadBlockSize,
                                                  ScaleBlockSize,
                                                  ADataType,
                                                  AScaleDataType,
                                                  BDataType,
                                                  BScaleDataType,
                                                  ATileDesc,
                                                  BTileDesc,
                                                  AMmaTileDesc,
                                                  BMmaTileDesc,
                                                  ABlockTransferSrcScalarPerVector,
                                                  BBlockTransferSrcScalarPerVector,
                                                  MPerBlock,
                                                  NPerBlock,
                                                  KPerBlock,
                                                  MPerXDL,
                                                  NPerXDL,
                                                  MRepeat,
                                                  NRepeat,
                                                  KPack>
    : BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
                                           ADataType,
                                           BDataType,
                                           ATileDesc,
                                           BTileDesc,
                                           AMmaTileDesc,
                                           BMmaTileDesc,
                                           ABlockTransferSrcScalarPerVector,
                                           BBlockTransferSrcScalarPerVector,
                                           MPerBlock,
                                           NPerBlock,
                                           KPerBlock,
                                           MPerXDL,
                                           NPerXDL,
                                           MRepeat,
                                           NRepeat,
                                           KPack>

{

    using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
                                                      ADataType,
                                                      BDataType,
                                                      ATileDesc,
                                                      BTileDesc,
                                                      AMmaTileDesc,
                                                      BMmaTileDesc,
                                                      ABlockTransferSrcScalarPerVector,
                                                      BBlockTransferSrcScalarPerVector,
                                                      MPerBlock,
                                                      NPerBlock,
                                                      KPerBlock,
                                                      MPerXDL,
                                                      NPerXDL,
                                                      MRepeat,
                                                      NRepeat,
                                                      KPack>;
    using Base::I0;
    using Base::I1;
    using Base::KRepeat;
    using Base::MWaves;
    using Base::NWaves;
    using Base::WaveSize;
    using Base::xdlops_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetWaveIdx;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_m3_k;
    using Base::b_block_desc_n0_n1_n2_n3_k;

    using Base::AMmaKStride;
    using Base::APackedSize;
    using Base::BMmaKStride;
    using Base::BPackedSize;
    using Base::KThreadChunk;

    using Base::KXdlPack;
    using Base::MXdlPack;
    using Base::NXdlPack;

    using AccType      = typename Base::AccType;
    using Tuple5       = typename Base::Tuple5;
    using ComputeTypeA = typename Base::ComputeTypeA;
    using ComputeTypeB = typename Base::ComputeTypeB;

    static constexpr index_t PrefetchStages  = 1;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    static constexpr auto ScalesPerKBlockSize =
        KPerBlock / ScaleBlockSize; // How many mx-vectors per K block

    //> How many mx-vectors in each row/col are processed in one call to xdlops_gemm.Run()
    static constexpr auto ScalesPerXdlopsRun =
        (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;

    //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
    static constexpr auto ScalesPerXdlopsRunPerThread =
        ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;

    using mx_scale_t = e8m0_bexp_t;
    static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
    static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
    static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
                  "A scale pack data type too large!");
    static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
                  "B scale pack data type too large!");
    static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a;
    static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b;
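
    // Worked example (hypothetical sizes): with e8m0 scales packed four to an
    // int32_t AScaleDataType, scale_pack_size_a = sizeof(int32_t) /
    // sizeof(e8m0_bexp_t) = 4; with KXdlPack = MXdlPack = 2 the static_assert
    // holds (2 * 2 % 4 == 0) and a_scale_thread_vec_size = 2 * 2 / 4 = 1, i.e.
    // a single packed int32_t carries all four e8m0 scales a thread needs for
    // one MXdlPack x KXdlPack group.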
    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
    }
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename AScaleGridBuffer,
              typename AScaleGridDesc,
              typename AScaleThreadTransfer,
              typename BScaleGridBuffer,
              typename BScaleGridDesc,
              typename BScaleThreadTransfer>
    __device__ void Run(
        // ABlockCopy
        const AGridDesc& a_grid_desc,
        const ABlockDesc& a_block_desc,
        ABlockTransfer& a_blockwise_copy,
        const AGridBuffer& a_grid_buf,
        ABlockBuffer& a_block_buf,
        const ABlockTransferStep& a_block_copy_step,
        // BBlockCopy
        const BGridDesc& b_grid_desc,
        const BBlockDesc& b_block_desc,
        BBlockTransfer& b_blockwise_copy,
        const BGridBuffer& b_grid_buf,
        BBlockBuffer& b_block_buf,
        const BBlockTransferStep& b_block_copy_step,
        // CThread
        CThreadBuffer& c_thread_buf,
        // A and B scales
        const AScaleGridDesc& a_scale_grid_desc,
        AScaleThreadTransfer& a_scale_thread_copy,
        const AScaleGridBuffer& a_scale_grid_buf,
        const BScaleGridDesc& b_scale_grid_desc,
        BScaleThreadTransfer& b_scale_thread_copy,
        const BScaleGridBuffer& b_scale_grid_buf,
        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
            a_scale_thread_desc.GetElementSpaceSize());

        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
            b_scale_thread_desc.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Prefetch a_scales
        static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
            static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
                a_scale_thread_copy.Run(a_scale_grid_desc,
                                        a_scale_grid_buf,
                                        a_scale_thread_desc,
                                        make_tuple(m0, k0, I0),
                                        a_scale_thread_buf);

                a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                       make_multi_index(0, I1, 0));
            });
            a_scale_thread_copy.MoveSrcSliceWindow(
                a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
        });

        // restore row id and advance to the next set of scales
        a_scale_thread_copy.MoveSrcSliceWindow(
            a_scale_grid_desc,
            make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));

        // Prefetch b_scales
        static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
            static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
                b_scale_thread_copy.Run(b_scale_grid_desc,
                                        b_scale_grid_buf,
                                        b_scale_thread_desc,
                                        make_tuple(n0, k0, I0),
                                        b_scale_thread_buf);

                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                       make_multi_index(0, I1, 0));
            });
            b_scale_thread_copy.MoveSrcSliceWindow(
                b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
        });

        // restore col id and advance to the next set of scales
        // NWaves * NPerXDL * NRepeat == NPerBlock
        b_scale_thread_copy.MoveSrcSliceWindow(
            b_scale_grid_desc,
            make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Initialize C
        c_thread_buf.Clear();
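
        // Slice-window bookkeeping above: the inner loops step the scale grid
        // one KXdlPack group at a time, each outer iteration jumps MWaves
        // (resp. NWaves) rows forward while rewinding K, and the final move
        // undoes the accumulated row offset and advances K past the tile just
        // consumed, leaving the copier positioned at the scales for the next K
        // block.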
        // main body
        if constexpr(HasMainLoop)
        {
            // loop over k with the step KPerBlock
            index_t i = 0;
            do
            {
                // -------------------------------------------------------------------------------
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds();

                static_for<0, KRepeat, 1>{}([&](auto k) {
                    constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
                                            (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
                            [&](auto chunk) {
                                constexpr auto a_k_step_chunk =
                                    k_step +
                                    chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
                                a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
                                                   make_tuple(Number<m0 / MXdlPack>{},
                                                              I0,
                                                              Number<m0 % MXdlPack>{},
                                                              I0,
                                                              Number<a_k_step_chunk>{}),
                                                   a_block_buf,
                                                   a_thread_desc_,
                                                   make_tuple(Number<m0 / MXdlPack>{},
                                                              I0,
                                                              Number<m0 % MXdlPack>{},
                                                              k,
                                                              Number<chunk * KThreadChunk>{}),
                                                   a_thread_buf);
                            });
                    });
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        // read block data in chunks to assemble correct thread vectors
                        static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
                            [&](auto chunk) {
                                constexpr auto b_k_step_chunk =
                                    k_step +
                                    chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
                                b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
                                                   make_tuple(Number<n0 / NXdlPack>{},
                                                              I0,
                                                              Number<n0 % NXdlPack>{},
                                                              I0,
                                                              Number<b_k_step_chunk>{}),
                                                   b_block_buf,
                                                   b_thread_desc_,
                                                   make_tuple(Number<n0 / NXdlPack>{},
                                                              I0,
                                                              Number<n0 % NXdlPack>{},
                                                              k,
                                                              Number<chunk * KThreadChunk>{}),
                                                   b_thread_buf);
                            });
                    });
                });

                static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
                    static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
                        static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
                            constexpr index_t a_scale_offset =
                                a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
                            constexpr index_t b_scale_offset =
                                b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));

                            static_assert(0 < ScalesPerXdlopsRunPerThread,
                                          "Must have at least one scale per Xdlops "
                                          "per Thread.");

                            vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
                            vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;

                            // Pack scale_thread_buf into scale_thread_vec
                            static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
                                a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
                                    a_scale_thread_buf[Number<a_scale_offset + s>{}];
                            });

                            static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
                                b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
                                    b_scale_thread_buf[Number<b_scale_offset + s>{}];
                            });

                            static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
                                static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
                                    static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
                                        constexpr auto kxdl = ikxdl + k0 * KXdlPack;

                                        vector_type<ComputeTypeA, KPack> a_thread_vec;
                                        vector_type<ComputeTypeB, KPack> b_thread_vec;

                                        static_for<0, KPack, 1>{}([&](auto ik) {
                                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                                    make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
                                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                                    make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
                                        });

                                        using mfma_input_type_a =
                                            typename vector_type<ComputeTypeA,
                                                                 xdlops_gemm.K1PerXdlops /
                                                                     APackedSize>::type;

                                        using mfma_input_type_b =
                                            typename vector_type<ComputeTypeB,
                                                                 xdlops_gemm.K1PerXdlops /
                                                                     BPackedSize>::type;

                                        using mfma_scale_input_type_a =
                                            typename vector_type<AScaleDataType,
                                                                 a_scale_thread_vec_size>::type;
                                        using mfma_scale_input_type_b =
                                            typename vector_type<BScaleDataType,
                                                                 b_scale_thread_vec_size>::type;

                                        constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
                                            make_tuple(m0, n0, imxdl, inxdl, 0));

                                        // MFMA accumulation
                                        xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
                                                                 ikxdl * NXdlPack + inxdl>(
                                            a_thread_vec.template AsType<mfma_input_type_a>(),
                                            a_scale_thread_vec
                                                .template AsType<mfma_scale_input_type_a>(),
                                            b_thread_vec.template AsType<mfma_input_type_b>(),
                                            b_scale_thread_vec
                                                .template AsType<mfma_scale_input_type_b>(),
                                            c_thread_buf.GetVectorTypeReference(
                                                Number<c_offset>{}));
                                    });
                                });
                            });
                        });
                    });
                });

                // Prefetch a_scales
                static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
                    static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
                        a_scale_thread_copy.Run(a_scale_grid_desc,
                                                a_scale_grid_buf,
                                                a_scale_thread_desc,
                                                make_tuple(m0, k0, I0),
                                                a_scale_thread_buf);

                        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
                                                               make_multi_index(0, I1, 0));
                    });
                    a_scale_thread_copy.MoveSrcSliceWindow(
                        a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
                });

                // restore row id and advance to the next set of scales
                a_scale_thread_copy.MoveSrcSliceWindow(
                    a_scale_grid_desc,
                    make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));

                // Prefetch b_scales
                static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
                    static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
                        b_scale_thread_copy.Run(b_scale_grid_desc,
                                                b_scale_grid_buf,
                                                b_scale_thread_desc,
                                                make_tuple(n0, k0, I0),
                                                b_scale_thread_buf);

                        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
                                                               make_multi_index(0, I1, 0));
                    });
                    b_scale_thread_copy.MoveSrcSliceWindow(
                        b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
                });

                // restore col id and advance to the next set of scales
                // NWaves * NPerXDL * NRepeat == NPerBlock
                b_scale_thread_copy.MoveSrcSliceWindow(
                    b_scale_grid_desc,
                    make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));

                block_sync_lds();
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                i += 1;
            } while(i < (num_loop - 1));
        }
        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
            block_sync_lds();
            static_for<0, KRepeat, 1>{}([&](auto k) {
                constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
                                        (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
                        [&](auto chunk) {
                            constexpr auto a_k_step_chunk =
                                k_step +
                                chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
                            a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
                                               make_tuple(Number<m0 / MXdlPack>{},
                                                          I0,
                                                          Number<m0 % MXdlPack>{},
                                                          I0,
                                                          Number<a_k_step_chunk>{}),
                                               a_block_buf,
                                               a_thread_desc_,
                                               make_tuple(Number<m0 / MXdlPack>{},
                                                          I0,
                                                          Number<m0 % MXdlPack>{},
                                                          k,
                                                          Number<chunk * KThreadChunk>{}),
                                               a_thread_buf);
                        });
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    // read block data in chunks to assemble correct thread vectors
                    static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
                        [&](auto chunk) {
                            constexpr auto b_k_step_chunk =
                                k_step +
                                chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
                                               make_tuple(Number<n0 / NXdlPack>{},
                                                          I0,
                                                          Number<n0 % NXdlPack>{},
                                                          I0,
                                                          Number<b_k_step_chunk>{}),
                                               b_block_buf,
                                               b_thread_desc_,
                                               make_tuple(Number<n0 / NXdlPack>{},
                                                          I0,
                                                          Number<n0 % NXdlPack>{},
                                                          k,
                                                          Number<chunk * KThreadChunk>{}),
                                               b_thread_buf);
                        });
                });
            });

            static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
                static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
                    static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
                        constexpr index_t a_scale_offset =
                            a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
                        constexpr index_t b_scale_offset =
                            b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));

                        static_assert(0 < ScalesPerXdlopsRunPerThread,
                                      "Must have at least one scale per Xdlops "
                                      "per Thread.");

                        vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
                        vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;

                        // Pack scale_thread_buf into scale_thread_vec
                        static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
                            a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
                                a_scale_thread_buf[Number<a_scale_offset + s>{}];
                        });

                        static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
                            b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
                                b_scale_thread_buf[Number<b_scale_offset + s>{}];
                        });

                        static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
                            static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
                                static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
                                    constexpr auto kxdl = ikxdl + k0 * KXdlPack;

                                    vector_type<ComputeTypeA, KPack> a_thread_vec;
                                    vector_type<ComputeTypeB, KPack> b_thread_vec;

                                    static_for<0, KPack, 1>{}([&](auto ik) {
                                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                                make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
                                        b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                                make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
                                    });

                                    using mfma_input_type_a =
                                        typename vector_type<ComputeTypeA,
                                                             xdlops_gemm.K1PerXdlops /
                                                                 APackedSize>::type;

                                    using mfma_input_type_b =
                                        typename vector_type<ComputeTypeB,
                                                             xdlops_gemm.K1PerXdlops /
                                                                 BPackedSize>::type;

                                    using mfma_scale_input_type_a =
                                        typename vector_type<AScaleDataType,
                                                             a_scale_thread_vec_size>::type;
                                    using mfma_scale_input_type_b =
                                        typename vector_type<BScaleDataType,
                                                             b_scale_thread_vec_size>::type;

                                    constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
                                        make_tuple(m0, n0, imxdl, inxdl, 0));

                                    // MFMA accumulation
                                    xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
                                                             ikxdl * NXdlPack + inxdl>(
                                        a_thread_vec.template AsType<mfma_input_type_a>(),
                                        a_scale_thread_vec
                                            .template AsType<mfma_scale_input_type_a>(),
                                        b_thread_vec.template AsType<mfma_input_type_b>(),
                                        b_scale_thread_vec
                                            .template AsType<mfma_scale_input_type_b>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                                });
                            });
                        });
                    });
                });
            });
        }
    }
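
    // Unlike the blockscale pipeline, no separate per-scale FMA pass is needed
    // here: the scaled MFMA path consumes the packed e8m0 exponent vectors
    // directly (the extra a/b scale operands passed to xdlops_gemm.Run above),
    // so dequantization happens inside the matrix instruction and results
    // accumulate straight into c_thread_buf.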
    // TODO: make this field protected when a_scale_thread_copy_ is moved here
    static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat / MXdlPack>{},
                   Number<KRepeat / KXdlPack>{},
                   Number<ScalesPerXdlopsRunPerThread * a_scale_thread_vec_size>{}));

    // TODO: make this field protected when b_scale_thread_copy_ is moved here
    static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
        make_tuple(Number<NRepeat / NXdlPack>{},
                   Number<KRepeat / KXdlPack>{},
                   Number<ScalesPerXdlopsRunPerThread * b_scale_thread_vec_size>{}));

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -532,6 +532,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                    });
                });

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_scale_thread_copy.Run(a_scale_grid_desc,
                                            a_scale_grid_buf,
@@ -560,8 +563,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
                                            b_scale_thread_buf);

                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 1));
        }

@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
 #ifndef __HIPCC_RTC__
@@ -149,6 +149,52 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
#endif
};

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename AScaleDataType,
          typename BDataType,
          typename BScaleDataType,
          typename DsDataType,
          typename EDataType,
          index_t ScaleBlockSize,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation>
struct DeviceMoEGemmMXBPreShuffle : public BaseOperator
{
    static constexpr index_t NumDTensor = DsDataType::Size();

#ifndef CK_CODE_GEN_RTC
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_a_scale,
                        const void* p_b,
                        const void* p_b_scale,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideAScale,
                        ck::index_t StrideB,
                        ck::index_t StrideBScale,
                        std::array<ck::index_t, NumDTensor> StrideDs,
                        ck::index_t StrideE,
                        ck::index_t KBatch,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;

    virtual int GetPreShuffleParameters() = 0;
#endif
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
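The interface above is consumed through the usual CK base-operator pattern. A minimal host-side sketch, assuming `op` is a concrete instance of a derived device op, all device pointers and strides are already prepared, and IsSupportedArgument comes from the BaseOperator side of the hierarchy (names are placeholders):

// Hypothetical driver code; `op` derives from DeviceMoEGemmMXBPreShuffle.
auto arg = op.MakeArgumentPointer(p_a, p_a_scale, p_b, p_b_scale,
                                  p_ds, p_e,
                                  M, N, K,
                                  StrideA, StrideAScale, StrideB, StrideBScale,
                                  StrideDs, StrideE,
                                  /*KBatch=*/1,
                                  a_element_op, b_element_op, cde_element_op);

auto invoker = op.MakeInvokerPointer();
if(op.IsSupportedArgument(arg.get()))
{
    float ave_ms = invoker->Run(arg.get(), StreamConfig{nullptr, /*time_kernel=*/true});
}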
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -60,6 +60,49 @@ struct DeviceGemmMultipleD_ABScale : public BaseOperator
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename AScaleType,
          typename BDataType,
          typename BScaleType,
          typename DsDataType,
          typename EDataType,
          index_t ScaleBlockM,
          index_t ScaleBlockN,
          index_t ScaleBlockK,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation>
struct DeviceGemmMultipleD_BlockScale_BPreshuffle : public BaseOperator
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        const ck::index_t M,
                        const ck::index_t N,
                        const ck::index_t K,
                        const ck::index_t StrideA,
                        const ck::index_t StrideB,
                        const std::array<ck::index_t, NumDTensor> StrideDs,
                        const ck::index_t StrideE,
                        const void* p_a_scale,
                        const void* p_b_scale,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;

    virtual int GetPreShuffleParameters() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
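Note that this interface takes the scale pointers after the stride arguments, unlike DeviceMoEGemmMXBPreShuffle above, which interleaves them with the tensor pointers. A call-site sketch (placeholder names, same assumptions as the sketch above):

auto arg = op.MakeArgumentPointer(p_a, p_b, p_ds, p_e,
                                  M, N, K,
                                  StrideA, StrideB, StrideDs, StrideE,
                                  p_a_scale, p_b_scale, // scales trail the strides here
                                  a_element_op, b_element_op, cde_element_op);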
@@ -0,0 +1,507 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
#include <map> // std::map is used in GetTypeString below
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ComputeTypeA,
|
||||
typename LDSTypeB = ComputeTypeB>
|
||||
struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle
|
||||
: public DeviceGemmMultipleD_BlockScale_BPreshuffle<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
LDSTypeA,
|
||||
LDSTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
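            // K_split: the K extent handled by each K-batch split, rounded up to a
            // whole number of KPerBlock tiles.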
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer =
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
|
||||
auto size_b_buffer =
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
|
||||
|
||||
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
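                    // Rotating through copies of A/B keeps each timed launch from
                    // re-hitting L2-resident data, so the measurement stays cold-cache.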
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
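                        // (with KBatch > 1, partial K results land in the same C tile,
                        //  so C must start from zero)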
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
            // occupancy is unconditionally requested as 2 so the compiler avoids AGPR usage
|
||||
constexpr index_t minimum_occupancy = 2;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK !=
|
||||
// KPerBlock)
|
||||
// {
|
||||
// return false;
|
||||
// }
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Padding to release this restriction
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t StrideA,
|
||||
const index_t StrideB,
|
||||
const std::array<index_t, NumDTensor> StrideDs,
|
||||
const index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t StrideA,
|
||||
const index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
const index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmXdlUniversal"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerXDL<<"x"<<NPerXDL << ", "
|
||||
<< "WaveMap: "
|
||||
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
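The dispatch logic above repeats one pattern throughout these device ops: a runtime property of K selects between pre-instantiated compile-time TailNumber kernel variants. A condensed, self-contained sketch of that pattern (generic names; the parity rule here is an illustration, not the exact CalculateKBlockLoopTailNum logic):

#include <cstdio>

enum class TailNumber { Odd, Even };

// Each TailNumber value is a distinct template instantiation, so the branch
// must happen on the host; the device code itself stays fully static.
template <TailNumber Tail>
void run_kernel()
{
    std::printf("running %s-tail kernel\n", Tail == TailNumber::Odd ? "odd" : "even");
}

void dispatch(int num_k_loops)
{
    if(num_k_loops % 2 == 1)
        run_kernel<TailNumber::Odd>();
    else
        run_kernel<TailNumber::Even>();
}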
@@ -0,0 +1,584 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
#include <map> // std::map is used in GetTypeString below
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
index_t ActivationOP = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool MulRoutedWeight = false,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ComputeTypeA,
|
||||
typename LDSTypeB = ComputeTypeB>
|
||||
struct DeviceMoeGemmBlockScale
|
||||
: public DeviceGemmMultipleD_BlockScale_BPreshuffle<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
using GridwiseGemm = GridwiseMoeGemmBlockScale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ActivationOP,
|
||||
NSwizzle,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
IndexType,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
LDSTypeA,
|
||||
LDSTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
static constexpr index_t APackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
const auto RunKernel = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
|
||||
std::array<std::size_t, NumDTensor> DsSize;
|
||||
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
|
||||
sizeof(ADataType) / APackedSize;
|
||||
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
|
||||
sizeof(BDataType) / BPackedSize;
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
|
||||
4 * (1 + GridwiseGemm::NWave);
|
||||
constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize /
|
||||
4 * (2) * (IsInputGemm ? 2 : 1);
|
||||
constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
|
||||
BlockSize / 4 * (IsInputGemm ? 2 : 1);
|
||||
constexpr auto estimated_reg_total =
|
||||
estimated_reg_a + estimated_reg_b + estimated_reg_c;
|
||||
|
||||
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
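            // Heuristic: the byte counts above divided by 4 approximate 32-bit VGPRs
            // per thread; at >= 256 estimated VGPRs only one wave fits per SIMD.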
|
||||
|
||||
constexpr auto MemoryDataOp =
|
||||
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
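            // Gemm1 (IsInputGemm) writes fresh rows (Set); gemm2 accumulates each
            // expert's partial output into its token's shared row, hence AtomicAdd.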
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("todo: only v1 & v2 support now");
|
||||
}
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
        // only KBatch == 1 is implemented for now
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_sorted_token_ids,
|
||||
const void* p_sorted_expert_ids,
|
||||
const void* p_max_token_id,
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t NumTokens,
|
||||
index_t TopK,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{static_cast<const index_t*>(p_sorted_token_ids),
|
||||
static_cast<const index_t*>(p_sorted_expert_ids),
|
||||
static_cast<const index_t*>(p_max_token_id),
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
NumTokens,
|
||||
TopK,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
// index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
                                          M, // NumTokens: dummy value, unused in this path
|
||||
0,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
1, // KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceMoeGEmm"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerXDL<<"x"<<NPerXDL << ", "
|
||||
<< "WaveMap: "
|
||||
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
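The three routing pointers fed to MakeArgument above come from the MoE sorting pass that precedes the GEMM. A rough host-side illustration of their contents; the layout details are assumptions here, the real buffers are produced by the companion sorting kernel:

#include <cstdint>
#include <vector>

using index_t = std::int32_t;

// 4 tokens, TopK = 2, 2 experts: each (token, expert) pair becomes one GEMM row.
// Rows are grouped by expert and padded up to the block granularity.
std::vector<index_t> sorted_token_ids{0, 1, 3, 2, /*expert-1 rows:*/ 2, 0, 1, 3};
std::vector<index_t> sorted_expert_ids{0, 1}; // assumed: one id per row-block
std::vector<index_t> max_token_id{8};         // assumed: count of valid rows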
@@ -0,0 +1,571 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
#include <map> // std::map is used in GetTypeString below
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t ScaleBlockSize,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
index_t ActivationOP = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool MulRoutedWeight = true,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = ADataType,
|
||||
typename ComputeTypeB = BDataType>
|
||||
struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ScaleBlockSize,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
using GridwiseGemm =
|
||||
GridwiseMoeGemmMX<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
ScaleBlockSize,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ActivationOP,
|
||||
NSwizzle,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
IndexType,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
static constexpr index_t APackedSize = packed_size_v<ADataType>;
|
||||
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
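    // packed_size_v is 2 for two-per-byte packed types (e.g. packed int4/fp4) and 1
    // otherwise; the flush-cache buffer sizing below divides byte counts by it.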
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto RunKernel = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
|
||||
std::array<std::size_t, NumDTensor> DsSize;
|
||||
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
|
||||
sizeof(ADataType) / APackedSize;
|
||||
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
|
||||
sizeof(BDataType) / BPackedSize;
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) /
|
||||
APackedSize / BlockSize / 4 *
|
||||
(1 + GridwiseGemm::NWave);
|
||||
constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) /
|
||||
BPackedSize / BlockSize / 4 * (2) *
|
||||
(IsInputGemm ? 2 : 1);
|
||||
constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
|
||||
BlockSize / 4 * (IsInputGemm ? 2 : 1);
|
||||
constexpr auto estimated_reg_total =
|
||||
estimated_reg_a + estimated_reg_b + estimated_reg_c;
|
||||
|
||||
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
|
||||
|
||||
constexpr auto MemoryDataOp =
|
||||
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("todo: only v1 & v3 support now");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
        // only KBatch == 1 is implemented for now
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_sorted_token_ids,
|
||||
const void* p_sorted_expert_ids,
|
||||
const void* p_max_token_id,
|
||||
const void* p_a,
|
||||
const void* p_a_scale,
|
||||
const void* p_b,
|
||||
const void* p_b_scale,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t NumTokens,
|
||||
index_t TopK,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideScaleA,
|
||||
index_t StrideB,
|
||||
index_t StrideScaleB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{static_cast<const index_t*>(p_sorted_token_ids),
|
||||
static_cast<const index_t*>(p_sorted_expert_ids),
|
||||
static_cast<const index_t*>(p_max_token_id),
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
NumTokens,
|
||||
TopK,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideScaleA,
|
||||
StrideB,
|
||||
StrideScaleB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_a_scale,
|
||||
const void* p_b,
|
||||
const void* p_b_scale,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideScaleA,
|
||||
index_t StrideB,
|
||||
index_t StrideScaleB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
                                          M, // NumTokens: dummy value, unused in this path
|
||||
0,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideScaleA,
|
||||
StrideB,
|
||||
StrideScaleB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceMoeGEmmMx"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerXDL<<"x"<<NPerXDL << ", "
|
||||
<< "WaveMap: "
|
||||
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
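The MX ops above carry one shared E8M0 scale per ScaleBlockSize elements of A and B. As a reference for what such a scale encodes, a small host-side decode sketch (per the OCP MX E8M0 definition; the helper name is ours, not CK API):

#include <cmath>
#include <cstdint>

// E8M0: 8-bit biased exponent, no sign bit, no mantissa. 0xFF encodes NaN;
// every other value x maps to 2^(x - 127).
inline float decode_e8m0(std::uint8_t x)
{
    if(x == 0xFF)
        return std::nanf("");
    return std::ldexp(1.0f, static_cast<int>(x) - 127);
}
// decode_e8m0(127) == 1.0f, decode_e8m0(126) == 0.5f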
@@ -0,0 +1,540 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
#include <map> // std::map is used in GetTypeString below
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t ScaleBlockSize,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
index_t ActivationOP = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool MulRoutedWeight = true,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = ADataType,
|
||||
typename ComputeTypeB = BDataType>
|
||||
struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ScaleBlockSize,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
using GridwiseGemm =
|
||||
GridwiseMoeGemmMXBNS<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
ScaleBlockSize,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ActivationOP,
|
||||
NSwizzle,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
IndexType,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
static constexpr index_t APackedSize = packed_size_v<ADataType>;
|
||||
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
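            // e.g. with K = 1024, KBatch = 2, KPerBlock = 128 (illustrative values):
            // k_grain = 256, so K_split = ceil(1024 / 256) * 128 = 512, i.e. the K extent
            // each K-batch slice contributes, rounded up to whole KPerBlock tiles.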

            const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);

            const auto RunKernel = [&](const auto& kernel) {
                if(stream_config.flush_cache)
                {

                    std::array<std::size_t, NumDTensor> DsSize;

                    Argument arg_ = arg;

                    const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
                        arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
                    const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
                        arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);

                    auto size_a_buffer =
                        a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
                    auto size_b_buffer =
                        b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);

                    const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
                        arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);

                    static_for<0, NumDTensor, 1>{}([&](auto i) {
                        using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
                        DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
                    });
                    ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
                        arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
                    rotating_mem.Print();

                    auto run_flush_cache = [&]() {
                        // flush icache
                        ck::utility::flush_icache();
                        // rotating mem
                        rotating_mem.Next();
                        // clear c mem
                        if(arg_.KBatch > 1)
                            hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
                                                             0,
                                                             arg_.M * arg_.N * sizeof(CDataType),
                                                             stream_config.stream_id_));
                    };

                    ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
                        stream_config,
                        run_flush_cache,
                        kernel,
                        dim3(gdx, gdy, gdz),
                        dim3(BlockSize),
                        0,
                        arg_);
                }
                else
                {
                    if(arg.KBatch > 1)
                        hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
                                                         0,
                                                         arg.M * arg.N * sizeof(CDataType),
                                                         stream_config.stream_id_));

                    ave_time = launch_and_time_kernel(
                        stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
                }
            };
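
            // Note: with flush_cache enabled, each timed launch first invalidates the
            // instruction cache and rotates the A/B/Ds buffers through a larger memory
            // footprint, so the measured time is not flattered by data already resident
            // in cache from the previous iteration.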

            // TODO: Check if this is the right algorithm for minimum_occupancy
            constexpr index_t minimum_occupancy =
                BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
                    ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
                       MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
                          ? 2
                          : 1
                    : 2;
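            // Heuristic as written above: Interwave scheduling always requests 2 blocks
            // per CU; Intrawave requests 2 only for the v3 pipeline with a small enough
            // tile, where register/LDS pressure plausibly allows double occupancy, and
            // falls back to 1 otherwise.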

            constexpr auto MemoryDataOp =
                IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
            if(has_main_k_block_loop)
            {
                // Tail number always full
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
                                                          true,
                                                          MemoryDataOp,
                                                          minimum_occupancy,
                                                          TailNumber::Full>;
                    RunKernel(kernel);
                }
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
                                                              true,
                                                              MemoryDataOp,
                                                              minimum_occupancy,
                                                              TailNumber::Odd>;
                        RunKernel(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
                                                              true,
                                                              MemoryDataOp,
                                                              minimum_occupancy,
                                                              TailNumber::Even>;
                        RunKernel(kernel);
                    }
                }
                else
                {
                    throw std::runtime_error(
                        "todo: only BlkGemmPipelineVersion v1 & v3 are supported now");
                }
            }
            else
            {
                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
                    const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
                                                          false,
                                                          MemoryDataOp,
                                                          minimum_occupancy,
                                                          TailNumber::Full>;
                    RunKernel(kernel);
                }
                else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                {
                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                    {
                        const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
                                                              false,
                                                              MemoryDataOp,
                                                              minimum_occupancy,
                                                              TailNumber::Odd>;
                        RunKernel(kernel);
                    }
                    else
                    {
                        const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
                                                              false,
                                                              MemoryDataOp,
                                                              minimum_occupancy,
                                                              TailNumber::Even>;
                        RunKernel(kernel);
                    }
                }
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        // only KBatch = 1 is implemented for now
        if(arg.KBatch > 1)
        {
            return false;
        }
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
        {
            return false;
        }

        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
                                                       GemmSpec == GemmSpecialization::NKPadding ||
                                                       GemmSpec == GemmSpecialization::MNKPadding ||
                                                       GemmSpec == GemmSpecialization::KPadding))
        {
            return false;
        }
        if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const void* p_sorted_token_ids,
                             const void* p_sorted_expert_ids,
                             const void* p_max_token_id,
                             const void* p_a,
                             const void* p_a_scale,
                             const void* p_b,
                             const void* p_b_scale,
                             std::array<const void*, NumDTensor> p_ds,
                             void* p_c,
                             index_t NumTokens,
                             index_t TopK,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideScaleA,
                             index_t StrideB,
                             index_t StrideScaleB,
                             std::array<index_t, NumDTensor> StrideDs,
                             index_t StrideC,
                             index_t KBatch,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{static_cast<const index_t*>(p_sorted_token_ids),
                        static_cast<const index_t*>(p_sorted_expert_ids),
                        static_cast<const index_t*>(p_max_token_id),
                        static_cast<const ADataType*>(p_a),
                        static_cast<const AScaleDataType*>(p_a_scale),
                        static_cast<const BDataType*>(p_b),
                        static_cast<const BScaleDataType*>(p_b_scale),
                        p_ds,
                        static_cast<CDataType*>(p_c),
                        NumTokens,
                        TopK,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideScaleA,
                        StrideB,
                        StrideScaleB,
                        StrideDs,
                        StrideC,
                        KBatch,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_a_scale,
                                                      const void* p_b,
                                                      const void* p_b_scale,
                                                      std::array<const void*, NumDTensor> p_ds,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideScaleA,
                                                      index_t StrideB,
                                                      index_t StrideScaleB,
                                                      std::array<ck::index_t, NumDTensor> StrideDs,
                                                      index_t StrideC,
                                                      index_t KBatch,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(nullptr,
                                          nullptr,
                                          nullptr,
                                          static_cast<const ADataType*>(p_a),
                                          static_cast<const AScaleDataType*>(p_a_scale),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<const BScaleDataType*>(p_b_scale),
                                          p_ds,
                                          static_cast<CDataType*>(p_c),
                                          M, // arbitrary NumTokens, not used on this path
                                          0,
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideScaleA,
                                          StrideB,
                                          StrideScaleB,
                                          StrideDs,
                                          StrideC,
                                          KBatch,
                                          a_element_op,
                                          b_element_op,
                                          c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            {BlockGemmPipelineVersion::v1, "v1"},
            {BlockGemmPipelineVersion::v2, "v2"},
            {BlockGemmPipelineVersion::v3, "v3"},
            {BlockGemmPipelineVersion::v4, "v4"},
            {BlockGemmPipelineVersion::v5, "v5"}};

        // clang-format off
        str << "DeviceMoeGemmMx"
            << "<"
            << getGemmSpecializationString(GemmSpec) << ", "
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            << ">"
            << " BlkSize: "
            << BlockSize << ", "
            << "BlkTile: "
            << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
            << "WaveTile: "
            << MPerXDL<<"x"<<NPerXDL << ", "
            << "WaveMap: "
            << MXdlPerWave<<"x" << NXdlPerWave<<", "
            << "VmemReadVec: "
            << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
            << "BlkGemmPipelineScheduler: "
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: "
            << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
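
For orientation, a host-side caller typically drives this device op through the
MakeArgument / MakeInvoker pair. The sketch below is illustrative only: the template
argument list is elided and the pointer and stride variables are hypothetical, not
part of this file.

    // Hedged sketch: DeviceOp stands for a fully instantiated DeviceMoeGemmMXBNS<...>
    // with concrete layouts, data types and tile sizes.
    using DeviceOp = /* DeviceMoeGemmMXBNS<...> */;
    DeviceOp op;
    auto argument = DeviceOp::MakeArgument(p_sorted_token_ids, p_sorted_expert_ids,
                                           p_max_token_id, p_a, p_a_scale, p_b, p_b_scale,
                                           p_ds, p_c, NumTokens, TopK, M, N, K,
                                           StrideA, StrideScaleA, StrideB, StrideScaleB,
                                           StrideDs, StrideC, /*KBatch=*/1,
                                           a_element_op, b_element_op, c_element_op);
    if(!DeviceOp::IsSupportedArgument(argument))
        throw std::runtime_error("unsupported problem size for this instance");
    auto invoker = DeviceOp::MakeInvoker();
    float ave_ms = invoker.Run(argument, StreamConfig{nullptr, true}); // time the kernel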

@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -153,9 +153,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
     static constexpr bool is_single_rate_mfma =
         (((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
           lcm_AK1_BK1 <= 4) ||
-         (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
-         ((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
-          lcm_AK1_BK1 < 32))
+         (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
             ? true
             : false;
     static constexpr auto is_scale_mfma = false;

@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -168,9 +168,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
     static constexpr bool is_single_rate_mfma =
         (((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
           lcm_AK1_BK1 <= 4) ||
-         (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
-         ((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
-          lcm_AK1_BK1 < 32))
+         (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
             ? true
             : false;
     static constexpr auto is_scale_mfma = false;

@@ -1192,7 +1190,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);

        // N0, K0, Blocksize*KPack
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

@@ -1200,7 +1197,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        // dummy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A matrix blockwise copy

@@ -1629,7 +1625,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);

        // N0, K0, Blocksize*KPack
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

@@ -1637,7 +1632,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        // dummy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A matrix blockwise copy

File diff suppressed because it is too large
@@ -1221,7 +1221,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
            }
        }
    }
#if 0
        // check gridwise gemm pipeline
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);

@@ -1232,7 +1231,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
                return false;
            }
        }
#endif
        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

@@ -2123,6 +2121,58 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
                n_thread_data_on_block_idx[I3]),
            ck::tensor_operation::element_wise::PassThrough{}};

        // calculate C grid descriptor
        constexpr auto DWORD_BYTES        = 4;
        constexpr auto atomic_vector_size = DWORD_BYTES / sizeof(CDataType);

        constexpr auto CShuffleBlockTransferClusterLengths = [&]() {
            if constexpr(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set)
            {
                return CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{};
            }
            // Atomic operation
            else
            {
                return generate_sequence_v2(
                    [&](auto i) {
                        if constexpr(i == 3)
                        {
                            return Number<
                                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
                                    .At(i) *
                                CShuffleBlockTransferScalarPerVector_NPerBlock /
                                atomic_vector_size>{};
                        }
                        else if constexpr(i == 1)
                        {
                            return Number<
                                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
                                    .At(i) /
                                CShuffleBlockTransferScalarPerVector_NPerBlock *
                                atomic_vector_size>{};
                        }
                        else
                        {
                            return Number<
                                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}
                                    .At(i)>{};
                        }
                    },
                    Number<4>{});
            }
        }();

        constexpr auto CShuffleBlockTransferScalarPerVector = [&]() {
            if constexpr(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set)
            {
                return CShuffleBlockTransferScalarPerVector_NPerBlock;
            }
            else
            {
                return atomic_vector_size;
            }
        }();
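
        // Worked example of the rebalancing above (illustrative values, not from this
        // PR): with CDataType = half_t, atomic_vector_size = 4 / 2 = 2 because AtomicAdd
        // is issued at DWORD granularity. If the Set-path config were ScalarPerVector = 8
        // with cluster lengths {1, 32, 1, 8}, the atomic path shrinks the vector to 2 and
        // rescales dims 1 and 3 to {1, 32 / 8 * 2, 1, 8 * 8 / 2} = {1, 8, 1, 32}, keeping
        // the total thread count unchanged.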

        // shuffle: blockwise copy C from LDS to global
        auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
            ThisThreadBlock, // ThreadGroup
@@ -2132,15 +2182,15 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     1,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+            decltype(CShuffleBlockTransferClusterLengths),
             Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
             CShuffleDataType,     // typename SrcData,
             CDataType,            // typename DstData,
             decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
             decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
-            Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
-            3,                                              // index_t VectorDim,
-            CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
+            Sequence<0, 1, 2, 3>,                 // typename DimAccessOrder,
+            3,                                    // index_t VectorDim,
+            CShuffleBlockTransferScalarPerVector, // index_t ScalarPerVector,
             true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
             false> // bool ThreadTransferDstResetCoordinateAfterRun>
            {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
@@ -183,27 +183,28 @@ struct GridwiseMoeGemm

    static constexpr index_t NumDTensor = DsDataType::Size();

-    static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
-    static constexpr bool is_single_rate_mfma =
-        (((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
-          lcm_AK1_BK1 <= 4) ||
-         (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
-         ((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
-          lcm_AK1_BK1 < 32))
-            ? true
-            : false;
-    static constexpr auto is_scale_mfma = false;
-    static constexpr auto mfma = MfmaSelector<ComputeTypeA,
-                                              MPerXdl,
-                                              NPerXdl,
-                                              ComputeTypeA,
-                                              is_single_rate_mfma,
-                                              is_scale_mfma>{};
-    static constexpr index_t KPack   = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
-    static constexpr index_t KLane   = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
-    static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
-    static constexpr index_t NLane   = NPerXdl;
-    static constexpr index_t NWave   = NPerBlock / NPerXdl / NXdlPerWave;
+    using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
+    static constexpr index_t KPack =
+        math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
+    static constexpr index_t KLane =
+        mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
+
+    static constexpr index_t KGroup = []() {
+        if constexpr(is_same_v<remove_cvref_t<BDataType>, f8_t>)
+            // On gfx950 there is an mfma that requires 32 f8 elements as input,
+            // split into 2 groups of 16 f8 elements.
+            // The 2 groups are not contiguous in the B preshuffled layout, and we do
+            // not want them to be contiguous in that layout, because a memory
+            // instruction can only read 16 f8 elements at a time.
+            return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
+        else
+            return 1;
+    }();
+
+    static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
+
+    static constexpr index_t NLane = NPerXdl;
+    static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
    // static constexpr index_t NumTokens = 1;
    static constexpr index_t SortedTileSize = MPerBlock;
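
    // Worked example for the constants above (illustrative): with an f8 mfma whose
    // selected_mfma.k_per_blk == 32, KGroup == 2 and each load fetches
    // KPack / KGroup = 16 f8 elements per lane. With KPerBlock = 128 and KLane = 4
    // this gives KRepeat = 128 / 4 / 16 = 2 iterations over the preshuffled B tile.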

@@ -262,7 +263,7 @@ struct GridwiseMoeGemm
    }
    __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
    {
-        return math::integer_divide_ceil(K, KLane * KPack);
+        return math::integer_divide_ceil(K, KLane * KPack / KGroup);
    }

    __host__ __device__ static auto CalculateKPadded(index_t K)

@@ -404,7 +405,7 @@ struct GridwiseMoeGemm

    __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
    {
-        constexpr index_t NkSwizzleNumber = Number<warpSize * KPack>{};
+        constexpr index_t NkSwizzleNumber = Number<warpSize * KPack / KGroup>{};
        return make_naive_tensor_descriptor(
            make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
            make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));

@@ -1314,7 +1315,7 @@ struct GridwiseMoeGemm
                make_multi_index(n_block_data_idx_on_grid,
                                 get_warp_local_1d_id() % NWave,
                                 0,
-                                 KPack * (get_thread_local_1d_id() % warpSize)));
+                                 KPack / KGroup * (get_thread_local_1d_id() % warpSize)));

        // LDS allocation for A and B: be careful of alignment
        // Cast after lds

@@ -1360,7 +1361,7 @@ struct GridwiseMoeGemm
                make_multi_index(n_block_data_idx_on_grid,
                                 get_warp_local_1d_id() % NWave,
                                 0,
-                                 KPack * (get_thread_local_1d_id() % warpSize)));
+                                 KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
            a_grid_desc_ak0_m_ak1,
            a_block_desc_ak0_m_ak1,

@@ -1899,7 +1900,8 @@ struct GridwiseMoeGemm
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                c_grid_desc_m_n, problem.MBlock, problem.NBlock);
        const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
+        // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
            return;

@@ -1908,12 +1910,13 @@ struct GridwiseMoeGemm
        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
            {
-                const index_t ecnt_prefix  = p_max_token_id[1 + expert_id];
-                const index_t prefix_block = ecnt_prefix * problem.NBlock;
-                const index_t ecnt         = p_max_token_id[2 + expert_id] - ecnt_prefix;
-                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
-                const index_t bid_new = blockIdx.x - prefix_block;
-                const index_t nid = __builtin_amdgcn_readfirstlane(
+                const index_t ecnt_prefix  = p_max_token_id[1 + expert_id];
+                const index_t prefix_block = ecnt_prefix * problem.NBlock;
+                const index_t ecnt         = p_max_token_id[2 + expert_id] - ecnt_prefix;
+                const index_t expert_swizzle =
+                    ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
+                const index_t bid_new = blockIdx.x - prefix_block;
+                const index_t nid     = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                const index_t mid =
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);

@@ -1924,9 +1927,9 @@ struct GridwiseMoeGemm
                return {blockIdx.x, blockIdx.y};
            }
        }();
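        // Swizzle mapping, illustrated: with expert_swizzle = 3, consecutive bid_new
        // values 0..23 cover (mid, nid) = (0..2, 0..7) before moving on to nid = 8..15,
        // so the same eight N tiles are revisited across the expert's M tiles before
        // the walk advances along N.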

        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;

        const index_t token0 =
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);

@@ -1938,11 +1941,9 @@ struct GridwiseMoeGemm
        constexpr auto AMRepeats = MPerBlock / AMThreads;
        const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;

-        if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
-           token0 >= problem.NumTokens)
+        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
            return;
-        StaticallyIndexedArray<IndexType, AMRepeats>
-            gather_offsets; //= p_sorted_token_ids[token_pos];
+        StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
        static_for<0, AMRepeats, 1>{}([&](auto m0) {
            const index_t fused_token = p_sorted_token_ids[token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
@@ -1952,7 +1953,8 @@ struct GridwiseMoeGemm
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
        });
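        // p_sorted_token_ids packs a token id in the low 24 bits and the top-k slot in
        // the high 8 bits; the mask above recovers the row to gather, and the offset is
        // scaled by problem.K to index the flat A matrix.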
-        const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
+        const index_t expert_stride =
+            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
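        // For the input (gate+up) GEMM each expert stores gate and up weights back to
        // back, so the per-expert stride doubles; the second half is addressed via
        // p_b_grid_up below.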

        // N0, K0, Blocksize*KPack
        const index_t n_block_data_idx_on_grid =
@@ -2025,7 +2027,7 @@ struct GridwiseMoeGemm
                make_multi_index(n_block_data_idx_on_grid,
                                 get_warp_local_1d_id() % NWave,
                                 0,
-                                 KPack * (get_thread_local_1d_id() % warpSize)));
+                                 KPack / KGroup * (get_thread_local_1d_id() % warpSize)));

        // LDS allocation for A and B: be careful of alignment
        // Cast after lds
@@ -2042,24 +2044,76 @@ struct GridwiseMoeGemm
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;

        StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                                  float,
                                  c_thread_buf.num_of_v_,
                                  c_thread_buf.s_per_v,
                                  true>
            c_thread_buf_fp32;

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

-        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
-                                                                         a_block_desc_ak0_m_ak1,
-                                                                         a_blockwise_copy,
-                                                                         a_grid_buf,
-                                                                         a_block_bufs,
-                                                                         a_block_slice_copy_step,
-                                                                         b_grid_desc_bpreshuffled,
-                                                                         b_blockwise_copy,
-                                                                         b_grid_buf,
-                                                                         b_block_bufs,
-                                                                         b_block_slice_copy_step,
-                                                                         c_thread_buf,
-                                                                         num_k_block_main_loop);
        if constexpr(IsInputGemm)
        {
            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid_up + expert_id * expert_stride / BPackedSize,
                b_grid_desc_bpreshuffled.GetElementSpaceSize());
            auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
                BDataType,
                BDataType,
                decltype(b_grid_desc_bpreshuffled),
                decltype(b_block_desc_bk0_n_bk1),
                Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
                Sequence<1, 2, 0, 3>,
                3,
                BBlockTransferSrcScalarPerVector,
                BThreadTransferSrcResetCoordinateAfterRun,
                true>(b_grid_desc_bpreshuffled,
                      make_multi_index(n_block_data_idx_on_grid,
                                       get_warp_local_1d_id() % NWave,
                                       0,
                                       KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_blockwise_copy,
                a_grid_buf,
                a_block_bufs,
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                b_blockwise_copy,
                b_blockwise_copy_up,
                b_grid_buf,
                b_grid_buf_up,
                b_block_bufs,
                b_block_slice_copy_step,
                c_thread_buf,
                c_thread_buf_up,
                num_k_block_main_loop);
        }
        else
        {

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_blockwise_copy,
                a_grid_buf,
                a_block_bufs,
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                b_blockwise_copy,
                b_grid_buf,
                b_block_bufs,
                b_block_slice_copy_step,
                c_thread_buf,
                num_k_block_main_loop);
        }

        // shuffle C and write out
        {
@@ -2087,6 +2141,185 @@ struct GridwiseMoeGemm
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            // mul scales
            const float* p_sorted_weights_0 = p_ds_grid[I0];
            const float* p_scale_b          = p_ds_grid[I1];
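            // D-tensor convention in this epilogue: p_ds_grid[I0] holds the per-token
            // A-side scales (sorted like the token ids), p_ds_grid[I1] the B-side
            // per-expert or per-channel scales, and p_ds_grid[I2] the routed top-k
            // weights consumed when MulRoutedWeight is set.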

            static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
            static_assert(M4 == 4);
            const index_t m1 = get_warp_local_1d_id() / NWave;
            const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;

            if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr)
            {
                if constexpr(PerTokenQuant)
                {
                    constexpr index_t scale_stride = (IsInputGemm ? 2 : 1);
                    p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock +
                                 get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl;
                }
                else
                {
                    p_scale_b += expert_id;
                }

                vector_type<int32_t, 4> scale_token_ids;
                vector_type<float, 4> topk_weights;
                static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
                    const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant];
                    static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
                        static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
                            const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
                                                  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
                            if constexpr(PerTokenQuant)
                            {
                                scale_token_ids =
                                    *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
                                        p_sorted_token_ids + m_pos);
                            }
                            if constexpr(MulRoutedWeight)
                            {
                                topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
                                    p_ds_grid[I2] + m_pos);
                            }
                            static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
                                float scale_a = [&]() {
                                    if constexpr(PerTokenQuant)
                                    {
                                        index_t fused_token = scale_token_ids.AsType<index_t>()[m4];
                                        const index_t token_offset = fused_token & 0xffffff;
                                        return token_offset < problem.NumTokens
                                                   ? p_sorted_weights_0[token_offset]
                                                   : 0.0;
                                    }
                                    else
                                    {
                                        return p_sorted_weights_0[0];
                                    }
                                }();
                                constexpr index_t c_offset =
                                    blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                                        make_tuple(m0, n0, m2 * M4 + m4));
                                constexpr auto cidx = Number<c_offset>{};
                                if constexpr(IsInputGemm) // gu fusion
                                {
                                    if constexpr(ActivationOperation == Activation::silu_and_mul)
                                    {
                                        const float scale_up =
                                            p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
                                                      PerTokenQuant];
                                        float gate = scale_a * scale_b * c_thread_buf[cidx];
                                        float up   = scale_a * scale_up * c_thread_buf_up[cidx];
                                        if constexpr(MulRoutedWeight)
                                        {
                                            gate = gate * topk_weights.AsType<float>()[m4];
                                            up   = up * topk_weights.AsType<float>()[m4];
                                        }
                                        if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
                                        {
                                            gate *= 16;
                                            up *= 16;
                                        }
                                        tensor_operation::element_wise::Silu{}(gate, gate);
                                        c_thread_buf_fp32(cidx) = gate * up;
                                    }
                                    else if(ActivationOperation == Activation::gelu_and_mul)
                                    {
                                        const float scale_up =
                                            p_scale_b[(n0 * NWave * NPerXdl + problem.N) *
                                                      PerTokenQuant];
                                        float gate = scale_a * scale_b * c_thread_buf[cidx];
                                        float up   = scale_a * scale_up * c_thread_buf_up[cidx];
                                        if constexpr(MulRoutedWeight)
                                        {
                                            gate = gate * topk_weights.AsType<float>()[m4];
                                            up   = up * topk_weights.AsType<float>()[m4];
                                        }
                                        if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
                                        {
                                            gate *= 16;
                                            up *= 16;
                                        }
                                        tensor_operation::element_wise::Gelu{}(gate, gate);
                                        c_thread_buf_fp32(cidx) = gate * up;
                                    }
                                }
                                else
                                {
                                    c_thread_buf_fp32(cidx) =
                                        scale_a * scale_b * c_thread_buf[cidx];
                                    if constexpr(MulRoutedWeight)
                                    {
                                        c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) *
                                                                  topk_weights.AsType<float>()[m4];
                                    }
                                }
                            });
                        });
                    });
                });
            }
            else
            {
                vector_type<float, 4> topk_weights; // for gemm2 only
                static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
                    static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
                        static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
                            const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
                                                  m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
                            if constexpr(MulRoutedWeight)
                            {
                                topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
                                    p_ds_grid[I2] + m_pos);
                            }
                            static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
                                constexpr index_t c_offset =
                                    blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                                        make_tuple(m0, n0, m2 * M4 + m4));
                                constexpr auto cidx = Number<c_offset>{};

                                if constexpr(IsInputGemm) // gu fusion
                                {
                                    if constexpr(ActivationOperation == Activation::silu_and_mul)
                                    {
                                        float gate = c_thread_buf[cidx];
                                        float up   = c_thread_buf_up[cidx];
                                        if constexpr(MulRoutedWeight)
                                        {
                                            gate = gate * topk_weights.AsType<float>()[m4];
                                            up   = up * topk_weights.AsType<float>()[m4];
                                        }
                                        tensor_operation::element_wise::Silu{}(gate, gate);
                                        c_thread_buf_fp32(cidx) = gate * up;
                                    }
                                    else if(ActivationOperation == Activation::gelu_and_mul)
                                    {
                                        float gate = c_thread_buf[cidx];
                                        float up   = c_thread_buf_up[cidx];
                                        if constexpr(MulRoutedWeight)
                                        {
                                            gate = gate * topk_weights.AsType<float>()[m4];
                                            up   = up * topk_weights.AsType<float>()[m4];
                                        }
                                        tensor_operation::element_wise::Gelu{}(gate, gate);
                                        c_thread_buf_fp32(cidx) = gate * up;
                                    }
                                }
                                else
                                {
                                    c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
                                    if constexpr(MulRoutedWeight)
                                    {
                                        c_thread_buf_fp32(cidx) = topk_weights.AsType<float>()[m4] *
                                                                  c_thread_buf_fp32[cidx];
                                    }
                                }
                            });
                        });
                    });
                });
            }

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

@@ -2184,18 +2417,8 @@ struct GridwiseMoeGemm

            const auto ds_grid_buf = generate_tuple(
                [&](auto i) {
-                    using DDataType       = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
-                    const DDataType* ptr_ = p_ds_grid[i];
-                    // hack logic here to support different kind of strides. todo fix it.
-                    // ascale t, 1; bscale E, N, 1, move ptr to E
-                    // if(i.value == 1)
-                    // {
-                    //     ptr_ +=
-                    //         expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N :
-                    //         1);
-                    // }
                    return make_dynamic_buffer<AddressSpaceEnum::Global>(
-                        ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
+                        p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
                },
                Number<NumDTensor>{});

@@ -2271,7 +2494,6 @@ struct GridwiseMoeGemm

            auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,

@@ -2310,7 +2532,7 @@ struct GridwiseMoeGemm
                    block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
                static_for<0, EMRepeats, 1>{}([&](auto m0) {
                    const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
-                    index_t token_offset      = fused_token & 0xffffff;
+                    IndexType token_offset    = fused_token & 0xffffff;
                    if constexpr(IsInputGemm)
                    {
                        token_offset = token_offset * problem.TopK + (fused_token >> 24);

@@ -2323,7 +2545,7 @@ struct GridwiseMoeGemm
                // each thread write its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
-                                              c_thread_buf,
+                                              c_thread_buf_fp32,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

File diff suppressed because it is too large
include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp (new file, 2652 lines; diff suppressed because it is too large)
include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp (new file, 2849 lines; diff suppressed because it is too large)
@@ -580,11 +580,6 @@ struct ThreadwiseTensorSliceTransfer_v2_gather
            });
        });

-        // printf("blockIdx.y: %d, tid: %d, dst_buf<%f>\n",
-        //        blockIdx.y,
-        //        threadIdx.x,
-        //        dst_buf(Number<0>{}));
-
        // move src coordinate back to slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {

@@ -1146,15 +1146,6 @@ struct MfmaSelector
#endif
    }

-    // Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
-    // See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3
-    // TODO: explore optimization opportunity by using new mfma instructions on gfx950
-    template <>
-    constexpr auto GetMfma<f8_t, 32, 32, pk_i4_t, true, false>()
-    {
-        return MfmaInstr::mfma_f32_32x32x16f8f8;
-    }
-
    template <>
    constexpr auto GetMfma<f8_t, 32, 32, f8_t, true, false>()
    {

@@ -881,11 +881,6 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32, OpselA, OpselB>
#endif
    }
};
#define BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS 1

#ifndef BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS
#define BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS 0
#endif

template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
struct intrin_mfma_scale_f32_16x16x128f8f6f4;
@@ -893,48 +888,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4;
template <index_t OpselA, index_t OpselB>
struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
{

-#define V_MFMA_SCALE_F32_16X16X128_F8F6F4(OPF_F8F6F4_CTRL_A,                     \
-                                          OPF_F8F6F4_CTRL_B,                     \
-                                          F8F6F4_VEC_TYPE_A,                     \
-                                          F8F6F4_VEC_TYPE_B,                     \
-                                          OPSEL_A_L,                             \
-                                          OPSEL_A_H,                             \
-                                          OPSEL_B_L,                             \
-                                          OPSEL_B_H)                             \
-    if constexpr((OpselA == 1 * OPSEL_A_L + 2 * OPSEL_A_H) &&                    \
-                 (OpselB == 1 * OPSEL_B_L + 2 * OPSEL_B_H))                      \
-        asm volatile("v_mfma_scale_f32_16x16x128_f8f6f4 %0, %1, %2, %3, %4, %5 " \
-                     "op_sel:[" #OPSEL_A_L "," #OPSEL_A_H "] "                   \
-                     "op_sel_hi:[" #OPSEL_B_L "," #OPSEL_B_H "] "                \
-                     "cbsz:" #OPF_F8F6F4_CTRL_A " blgp:" #OPF_F8F6F4_CTRL_B      \
-                     : "+v"(reg_c.template AsType<float4_t>()(Number<0>{}))      \
-                     : "v"(bit_cast<F8F6F4_VEC_TYPE_A>(reg_a)),                  \
-                       "v"(bit_cast<F8F6F4_VEC_TYPE_B>(reg_b)),                  \
-                       "v"(reg_c.template AsType<float4_t>()[Number<0>{}]),      \
-                       "v"(scale_a),                                             \
-                       "v"(scale_b))
-#define BOOL4_CASES(F) \
-    do                 \
-    {                  \
-        F(0, 0, 0, 0); \
-        F(0, 0, 0, 1); \
-        F(0, 0, 1, 0); \
-        F(0, 0, 1, 1); \
-        F(0, 1, 0, 0); \
-        F(0, 1, 0, 1); \
-        F(0, 1, 1, 0); \
-        F(0, 1, 1, 1); \
-        F(1, 0, 0, 0); \
-        F(1, 0, 0, 1); \
-        F(1, 0, 1, 0); \
-        F(1, 0, 1, 1); \
-        F(1, 1, 0, 0); \
-        F(1, 1, 0, 1); \
-        F(1, 1, 1, 0); \
-        F(1, 1, 1, 1); \
-    } while(0)

    template <class FloatC>
    __device__ static void Run(const f8x32_t& reg_a,
                               const int32_t& scale_a,
@@ -943,7 +896,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
                               FloatC& reg_c)
    {
#if defined(__gfx950__)
-#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS
        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
@@ -956,11 +908,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
                scale_a,
                OpselB, // OPSEL
                scale_b);
-#else
-#define f8_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(0, 0, int32x8_t, int32x8_t, __VA_ARGS__)
-        BOOL4_CASES(f8_cases);
-#undef f8_cases
-#endif
#else
        ignore = reg_a;
        ignore = scale_a;
@@ -978,7 +925,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
                               FloatC& reg_c)
    {
#if defined(__gfx950__)
-#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS
        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
@@ -991,10 +937,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
                scale_a,
                OpselB, // OPSEL
                scale_b);
-#else
-#define bf8_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(1, 1, int32x8_t, int32x8_t, __VA_ARGS__)
-        BOOL4_CASES(bf8_cases);
-#endif
#else
        ignore = reg_a;
        ignore = scale_a;
@@ -1012,7 +954,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
                               FloatC& reg_c)
    {
#if defined(__gfx950__)
-#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS
        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
@@ -1025,11 +966,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
                scale_a,
                OpselB, // OPSEL
                scale_b);
-#else
-#define f8bf8_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(0, 1, int32x8_t, int32x8_t, __VA_ARGS__)
-        BOOL4_CASES(f8bf8_cases);
-#undef f8bf8_cases
-#endif
#else
        ignore = reg_a;
        ignore = scale_a;
@@ -1047,7 +983,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
                               FloatC& reg_c)
    {
#if defined(__gfx950__)
-#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS
        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
        reg_c.template AsType<float4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
@@ -1060,11 +995,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
                scale_a,
                OpselB, // OPSEL
                scale_b);
-#else
-#define bf8f8_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(1, 0, int32x8_t, int32x8_t, __VA_ARGS__)
-        BOOL4_CASES(bf8f8_cases);
-#undef bf8f8_cases
-#endif
#else
        ignore = reg_a;
        ignore = scale_a;
@@ -1141,24 +1071,13 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
    }

    template <class FloatC>
-    __device__ static void
-    Run(const f4x32_t& reg_a, // misalignment between pk_f4_t, 32 and f4_t, 32
-        const int32_t scale_a,
-        const f4x32_t& reg_b,
-        const int32_t scale_b,
-        FloatC& reg_c)
+    __device__ static void Run(const f4x32_t& reg_a,
+                               const int32_t scale_a,
+                               const f4x32_t& reg_b,
+                               const int32_t scale_b,
+                               FloatC& reg_c)
    {
-#if 0
-        if(get_thread_local_1d_id()){
-            printf("Tid: %03d, Scale A: %08x, Scale B: %08x, OpSelA: %d, OpSelB: %d\n",
-                   get_thread_local_1d_id(),
-                   *reinterpret_cast<const uint32_t*>(&scale_a), *reinterpret_cast<const
-                   uint32_t*>(&scale_b),
-                   OpselA, OpselB);
-        }
-#endif
#if defined(__gfx950__)
-#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS
        int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
        int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
        using arg_type  = int32x8_t;
@@ -1173,11 +1092,6 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
                scale_a,
                OpselB, // OPSEL
                scale_b);
-#else
-#define f4_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(4, 4, int32x4_t, int32x4_t, __VA_ARGS__)
-        BOOL4_CASES(f4_cases);
-#undef f4_cases
-#endif
#else
        ignore = reg_a;
        ignore = scale_a;
@@ -1186,9 +1100,7 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
        ignore = reg_c;
#endif
    }
-#undef BOOL4_CASES
-#undef V_MFMA_SCALE_F32_16X16X128_F8F6F4
-}; // namespace ck
+};

template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x128f8f6f4;

@@ -165,17 +165,6 @@ inline constexpr bool is_native_type()
    is_same<T, f8_fnuz_t>::value || is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value;
}

-template <typename T>
-struct is_f8f6f4
-{
-    static constexpr bool value =
-        is_same_v<T, f8_t> || is_same_v<T, bf8_t> || is_same_v<T, f6_t> || is_same_v<T, bf6_t> ||
-        is_same_v<T, f6x16_pk_t> || is_same_v<T, f6x32_pk_t> || is_same_v<T, bf6x16_pk_t> ||
-        is_same_v<T, bf6x32_pk_t> || is_same_v<T, f4_t> || is_same_v<T, f4x2_pk_t>;
-};
-template <typename T>
-inline constexpr bool is_f8f6f4_v = is_f8f6f4<T>::value;
-
// scalar_type
template <typename TV>
struct scalar_type;

@@ -87,6 +87,19 @@ __device__ static bool is_thread_local_1d_id_idx()
    return ((tid == Ids) || ...);
}

// Use `CK_PRINT<T1, T2, ...>()` to inspect the types T1, T2, ...
// Use `CK_PRINT<v1, v2, ...>()` to inspect constexpr values v1, v2, ... of the same type
// In a non-evaluated context, you can use `using _dummy = decltype(CK_PRINT<...>());`
// Set BUILD_DEV to OFF to avoid enabling Werror
template <auto... val>
[[deprecated("Helper function to print values")]] inline constexpr void CK_PRINT()
{
}
template <typename... type>
[[deprecated("Helper function to print values")]] inline constexpr void CK_PRINT()
{
}
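// e.g. instantiating `CK_PRINT<KPack, KRepeat>();` in device code makes the compiler
// emit a deprecation warning whose diagnostic spells out the instantiation (such as
// CK_PRINT<32, 2>), a cheap way to read constexpr results without running the kernel.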

} // namespace debug
} // namespace ck

@@ -1036,11 +1036,11 @@ struct vector_type<T, 128, typename ck::enable_if_t<is_native_type<T>()>>
        StaticallyIndexedArray<d32_t, 4> d32x4_;
        StaticallyIndexedArray<d64_t, 2> d64x2_;
        StaticallyIndexedArray<d128_t, 1> d128x1_;
-    } data_;
+    } data_ = {d128_t{0}};

-    __host__ __device__ constexpr vector_type() : data_{type{0}} {}
+    __attribute__((host)) __attribute__((device)) constexpr vector_type() {}

-    __host__ __device__ constexpr vector_type(type v) : data_{v} {}
+    __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; }

    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const

@@ -1164,11 +1164,11 @@ struct vector_type<T, 256, typename ck::enable_if_t<is_native_type<T>()>>
        StaticallyIndexedArray<d64_t, 4> d64x4_;
        StaticallyIndexedArray<d128_t, 2> d128x2_;
        StaticallyIndexedArray<d256_t, 1> d256x1_;
-    } data_;
+    } data_ = {d256_t{0}};

-    __host__ __device__ constexpr vector_type() : data_{type{0}} {}
+    __attribute__((host)) __attribute__((device)) constexpr vector_type() {}

-    __host__ __device__ constexpr vector_type(type v) : data_{v} {}
+    __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; }

    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const

@@ -2228,7 +2228,9 @@ using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;

// e8m0
using e8m0x4_bexp_t = typename vector_type<e8m0_bexp_t, 4>::type;

// pack int4
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;

@@ -6,6 +6,7 @@
#include "ck/utility/functional.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/type.hpp"

namespace ck {

@@ -107,7 +108,7 @@ struct identity
    template <typename T>
    __host__ __device__ constexpr T&& operator()(T&& arg) const noexcept
    {
-        return forward<T>(arg);
+        return ck::forward<T>(arg);
    }
};