mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
[Block Scale GEMM] Optimized block scale gemm (#1950)
* Added two kernel for M=32 problem
* Comment the first one
* Enable multiply_multiply for Scale_Block_M = 1 for deepseek
* Modify the a_thread offset since the A data load is different from B.
* edit fp8 ab scale for Scale_Block_M=1
* edit GemmSpec to MNKPadding
* enable blockwise pipelie v1 and v2. v1 is work for small K.
* add instance for gemm_ab_scale
* fix cmakelist of ckProfiler
* optimize blockscale gemm. todo: reduce vgpr usage
* fix a correctness bug
* sanity checked
* revert ckprofiler cmake changes
* clang format
* revert unnecessary changes.
* remove commented codes.
* split weight preshuffle library targets
* bring back enable-post-misched=0
* fix build issues for gemm_multiply_multiply_fp8 instances
* fix clang format
* add verbose build flag when building for all targets
* reduce path names for new instances
* fix paths in cmake
* refactor gemm_multiply_multiply library target
* fix a bug in example
* fix example 65 cmake
* reduce the number of threads when building libs for all targets to 50
* use ninja to build for all targets
* reduce teh number of threads when building for all targets
* reduce the number of threads to 32 when building libs for all targets to 50
---------
Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: chenjun <junchen2@amd.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
[ROCm/composable_kernel commit: cbd74c2d12]
This commit is contained in:
4
Jenkinsfile
vendored
4
Jenkinsfile
vendored
@@ -1145,11 +1145,11 @@ pipeline {
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \
|
||||
-D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
|
||||
-D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j32 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
|
||||
@@ -55,7 +55,7 @@ using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr ck::index_t Scale_Block_M = 128;
|
||||
static constexpr ck::index_t Scale_Block_M = 1;
|
||||
static constexpr ck::index_t Scale_Block_N = 128;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
@@ -65,14 +65,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
128, 128,
|
||||
128, 16, 16,
|
||||
16, 128,
|
||||
256, 16, 16,
|
||||
16, 16,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
|
||||
1, 2,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 2, S<1, 16, 1, 16>, S<8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -80,11 +80,12 @@ int main(int argc, char* argv[])
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
bool flush_cache = true;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t M = 128;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 1024;
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
@@ -100,7 +101,7 @@ int main(int argc, char* argv[])
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
else if(argc == 8)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
@@ -110,16 +111,19 @@ int main(int argc, char* argv[])
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideE = std::stoi(argv[9]);
|
||||
flush_cache = std::stoi(argv[7]);
|
||||
|
||||
StrideA = K;
|
||||
StrideB = K;
|
||||
StrideE = N;
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n");
|
||||
printf("arg4 to 6: M, N, K\n");
|
||||
printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
@@ -182,9 +186,15 @@ int main(int argc, char* argv[])
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 5:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
default:
|
||||
@@ -194,6 +204,16 @@ int main(int argc, char* argv[])
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
}
|
||||
#endif
|
||||
#if 0
|
||||
for(int im =0; im< (M + Scale_Block_M - 1) / Scale_Block_M; im++){
|
||||
float row_sum = .0;
|
||||
for(int ik =0; ik< (K + Scale_Block_K - 1) / Scale_Block_K; ik++){
|
||||
printf("%lf ",a1_m_k(im, ik));
|
||||
row_sum += a1_m_k(im, ik);
|
||||
}
|
||||
printf("sum: %lf\n", row_sum * 128);
|
||||
}
|
||||
#endif
|
||||
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize());
|
||||
@@ -239,12 +259,24 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float ave_time = .0;
|
||||
|
||||
if(flush_cache)
|
||||
{
|
||||
int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype;
|
||||
|
||||
ave_time = invoker.Run(argument,
|
||||
StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100});
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
@@ -140,14 +140,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
|
||||
// clang-format off
|
||||
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
|
||||
256, 256, 128,
|
||||
128, 128, 128,
|
||||
16, 16,
|
||||
16, 16,
|
||||
8, 8,
|
||||
32, 32,
|
||||
2, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
|
||||
1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -453,7 +453,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
|
||||
// latency
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
|
||||
@@ -784,7 +784,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
|
||||
// latency
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
|
||||
@@ -7,10 +7,10 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Naive pipeline with lowest resource request per WGP
|
||||
// GlobalPrefetchStages: 1
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 0
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
@@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
true>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
@@ -117,10 +118,15 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
KPack,
|
||||
true>;
|
||||
using Base::A_K1;
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
@@ -131,19 +137,43 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetWaveIdx;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
static constexpr index_t AMmaKStride = xdlops_gemm.K0PerXdlops * KPack;
|
||||
static constexpr index_t BMmaKStride = xdlops_gemm.K0PerXdlops * KPack;
|
||||
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
// Force mfma not cross the scaleblock
|
||||
__device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
|
||||
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
|
||||
}
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
@@ -151,11 +181,116 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
ignore = num_loop;
|
||||
return TailNumber::Full;
|
||||
return num_loop == 1 ? TailNumber::Odd : TailNumber::Full;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
// A/B split schedule
|
||||
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
|
||||
? HotLoopInstList::A_LDS_Read_Inst_Num
|
||||
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
|
||||
? HotLoopInstList::B_LDS_Read_Inst_Num
|
||||
: HotLoopInstList::B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
|
||||
constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
|
||||
|
||||
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
constexpr auto num_dsread_a_mfma =
|
||||
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
constexpr auto num_dsread_b_mfma =
|
||||
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
// stage 1
|
||||
// Separate this part?
|
||||
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
|
||||
// sizeof(ComputeDataType) / sizeof(BDataType)
|
||||
// ? sizeof(ComputeDataType) / sizeof(ADataType)
|
||||
// : sizeof(ComputeDataType) / sizeof(BDataType);
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue =
|
||||
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
|
||||
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
|
||||
});
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100,
|
||||
num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
|
||||
ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
|
||||
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100,
|
||||
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
|
||||
ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
int NumKBlockPerScale,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
@@ -169,6 +304,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CScaleThreadDesc,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
@@ -196,6 +332,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
const CScaleThreadDesc& c_scale_thread_desc,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// AScaleThreadCopy
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
@@ -210,11 +347,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
const BScaleThreadTransferStep& b_scale_thread_copy_step,
|
||||
// num_loop
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
index_t num_loop) const
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// assume kperblock = scaleblockk
|
||||
ignore = num_loop_per_scale;
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
@@ -223,6 +359,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
c_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
@@ -231,11 +369,26 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
@@ -243,17 +396,101 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
|
||||
constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
|
||||
constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
|
||||
|
||||
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
constexpr index_t c_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
|
||||
constexpr index_t a_offset =
|
||||
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
|
||||
constexpr index_t b_offset =
|
||||
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
c_scale_thread_buf(Number<c_offset>{}) =
|
||||
a_scale_thread_buf[Number<a_offset>{}] *
|
||||
b_scale_thread_buf[Number<b_offset>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
1,
|
||||
xdlops_gemm.GetRegSizePerXdlops(),
|
||||
true>
|
||||
c_thread_buf_per_scale;
|
||||
|
||||
// Local prefetch 1
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
@@ -261,13 +498,85 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
// -------------------------------------------------------------------------------------------
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
constexpr index_t cscale_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(
|
||||
c_scale_thread_buf[Number<cscale_offset>{}]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
constexpr index_t c_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
|
||||
constexpr index_t a_offset =
|
||||
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
|
||||
constexpr index_t b_offset =
|
||||
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
c_scale_thread_buf(Number<c_offset>{}) =
|
||||
a_scale_thread_buf[Number<a_offset>{}] *
|
||||
b_scale_thread_buf[Number<b_offset>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
@@ -289,19 +598,70 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
i += 1;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
@@ -311,46 +671,41 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(
|
||||
c_scale_thread_buf[Number<cscale_offset>{}]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
constexpr index_t c_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
|
||||
constexpr index_t a_offset =
|
||||
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
|
||||
constexpr index_t b_offset =
|
||||
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
c_scale_thread_buf(Number<c_offset>{}) =
|
||||
a_scale_thread_buf[Number<a_offset>{}] *
|
||||
b_scale_thread_buf[Number<b_offset>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
i += 1;
|
||||
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
@@ -371,49 +726,143 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(
|
||||
c_scale_thread_buf[Number<cscale_offset>{}]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0,
|
||||
I0,
|
||||
kscale0 * KRepeat / num_scale_k_block + k0,
|
||||
ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
|
||||
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
|
||||
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(
|
||||
c_scale_thread_buf[Number<cscale_offset>{}]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
using Base::b_thread_copy_;
|
||||
using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
ComputeDataType,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
|
||||
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
true>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
@@ -117,7 +118,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
KPack,
|
||||
true>;
|
||||
using Base::I0;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
@@ -270,11 +272,26 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if(num_loop_per_scale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
@@ -282,7 +299,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
// Local prefill 1
|
||||
@@ -360,17 +376,32 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if(num_loop_per_scale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
@@ -378,8 +409,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
b_scale_thread_copy_step);
|
||||
|
||||
@@ -453,17 +482,32 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if(num_loop_per_scale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
@@ -471,7 +515,6 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
@@ -528,7 +571,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
});
|
||||
});
|
||||
@@ -586,7 +629,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[m0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -96,7 +96,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
true>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
@@ -117,7 +118,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
KPack,
|
||||
true>;
|
||||
using Base::I0;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
@@ -177,11 +179,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
|
||||
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle = 4;
|
||||
// HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle = 4;
|
||||
// HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
@@ -262,6 +264,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
int NumKBlockPerScale,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
@@ -275,6 +278,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CScaleThreadDesc,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
@@ -302,6 +306,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
const CScaleThreadDesc& c_scale_thread_desc,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// AScaleThreadCopy
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
@@ -316,12 +321,14 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
const BScaleThreadTransferStep& b_scale_thread_copy_step,
|
||||
// num_loop
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
index_t num_loop) const
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_assert(CScaleThreadDesc{}.GetLength(Number<0>{}) == 1,
|
||||
"Pipeline v3 only support scaleblocksliceK=1");
|
||||
static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1,
|
||||
"Pipeline v3 only support scaleblocksliceN=1");
|
||||
// assume kperblock = scaleblockk
|
||||
ignore = num_loop_per_scale;
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
@@ -330,6 +337,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
c_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
@@ -338,11 +347,26 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
@@ -350,8 +374,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
|
||||
});
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
@@ -363,10 +391,44 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
1,
|
||||
xdlops_gemm.GetRegSizePerXdlops(),
|
||||
true>
|
||||
c_thread_buf_per_scale;
|
||||
|
||||
// Local prefetch 1
|
||||
block_sync_lds();
|
||||
@@ -409,7 +471,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
@@ -430,19 +495,23 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(c_scale_thread_buf[m0]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
@@ -462,11 +531,27 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
@@ -474,7 +559,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -487,7 +571,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
@@ -507,15 +594,15 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<t>{}] *
|
||||
type_convert<AccDataType>(a_scale_thread_buf[I0]) *
|
||||
type_convert<AccDataType>(b_scale_thread_buf[I0]);
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(c_scale_thread_buf[m0]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -177,14 +178,57 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer =
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
|
||||
auto size_b_buffer =
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
|
||||
|
||||
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
@@ -195,7 +239,7 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always 1
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
@@ -208,127 +252,13 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
// Tail number could be One to Seven
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::One>;
|
||||
Run(kernel);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
Run(kernel);
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Two>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Three)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Three>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Four)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Four>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Five)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Five>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Six>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Seven)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Seven>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
@@ -337,6 +267,16 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ave_time;
|
||||
@@ -363,10 +303,11 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != KPerBlock)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK !=
|
||||
// KPerBlock)
|
||||
// {
|
||||
// return false;
|
||||
// }
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
|
||||
@@ -234,7 +234,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
|
||||
}
|
||||
|
||||
__device__ static auto MakeAGridDescriptor_AK0_M_AK1(
|
||||
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
|
||||
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
@@ -316,7 +316,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static auto MakeBGridDescriptor_BK0_N_BK1(
|
||||
__host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
|
||||
index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
|
||||
{
|
||||
const auto b_grid_desc_nraw_kraw = [&]() {
|
||||
@@ -431,6 +431,13 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
}
|
||||
}();
|
||||
|
||||
// pad M and N
|
||||
return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(M, MPad - M),
|
||||
make_right_pad_transform(N, NPad - N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
#if 0
|
||||
using GemmSpecialization = tensor_operation::device::GemmSpecialization;
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
|
||||
@@ -468,6 +475,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
// not pad M or N
|
||||
return c_grid_desc_mraw_nraw;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__host__ __device__ static auto MakeDsGridDescriptor_M_N(
|
||||
@@ -665,40 +673,19 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
// in some cases.
|
||||
else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1
|
||||
? 1
|
||||
: 32 * 4 / KPerBlock / sizeof(LDSTypeA);
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
|
||||
make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));
|
||||
constexpr auto a_lds_block_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(
|
||||
make_tuple(Number<MPerBlock>{}, Number<AK0Number>{})),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
|
||||
make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_ak0_mldslayer_m_ak1,
|
||||
make_tuple(make_pass_through_transform(AK0Number),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
return a_lds_block_desc_ak0_m_ak1;
|
||||
return a_lds_block_desc_permuted;
|
||||
}
|
||||
else // ColumnMajor A
|
||||
{
|
||||
@@ -800,42 +787,19 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
// NLdsLayer * K0 as logical Bank
|
||||
constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1
|
||||
? 1
|
||||
: 32 * 4 / KPerBlock / sizeof(LDSTypeB);
|
||||
;
|
||||
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
|
||||
make_tuple(BK1Number, Number<KPerBlock * NLdsLayer>{}, I1));
|
||||
constexpr auto b_lds_block_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(BK1Number, Number<KPerBlock>{}, I1));
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(
|
||||
make_tuple(Number<NPerBlock>{}, Number<BK0Number>{})),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number<NLdsLayer>{})),
|
||||
make_pass_through_transform(Number<NPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_bk0_nldslayer_n_bk1,
|
||||
make_tuple(make_pass_through_transform(BK0Number),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(Number<NPerBlock / NLdsLayer>{}, Number<NLdsLayer>{})),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
return b_lds_block_desc_bk0_n_bk1;
|
||||
return b_lds_block_desc_permuted;
|
||||
}
|
||||
else // RowMajor B
|
||||
{
|
||||
@@ -1001,7 +965,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
|
||||
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
|
||||
{
|
||||
if(!(karg.M % MPerBlock == 0))
|
||||
{
|
||||
@@ -1018,7 +983,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
|
||||
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
|
||||
{
|
||||
if(!(karg.N % NPerBlock == 0))
|
||||
{
|
||||
@@ -1366,28 +1332,39 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
const index_t ScaleSliceSizeM = 1;
|
||||
const index_t ScaleSliceSizeN = 1;
|
||||
const index_t ScaleSliceSizeK = 1;
|
||||
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
|
||||
constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
|
||||
constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
|
||||
|
||||
// ScaleSliceSizeK is last dimension in A/B scale for vector memory access
|
||||
// ScaleSliceSizeK is first dimension in C scale for packed math
|
||||
constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
|
||||
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
auto a_thread_offset =
|
||||
get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl;
|
||||
|
||||
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
|
||||
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
|
||||
|
||||
constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<ScaleSliceSizeK>{}, Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeN>{}));
|
||||
|
||||
auto a_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<AScaleType,
|
||||
AScaleType,
|
||||
decltype(a_scale_grid_desc_am_ak),
|
||||
decltype(a_scale_thread_desc),
|
||||
Sequence<ScaleSliceSizeM, ScaleSliceSizeK>,
|
||||
Sequence<1, ScaleSliceSizeK>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
ScaleSliceSizeK,
|
||||
1,
|
||||
false>(
|
||||
a_scale_grid_desc_am_ak, make_multi_index(block_m_id * MPerBlock / ScaleBlockM, 0));
|
||||
a_scale_grid_desc_am_ak,
|
||||
make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0));
|
||||
|
||||
auto b_scale_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v2<BScaleType,
|
||||
@@ -1397,17 +1374,21 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
ScaleSliceSizeK,
|
||||
1,
|
||||
false>(
|
||||
b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
|
||||
|
||||
constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
|
||||
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, 1);
|
||||
// constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
|
||||
constexpr auto a_scale_thread_slice_copy_step =
|
||||
make_tuple(make_multi_index(MWaves * MPerXdl, 0),
|
||||
make_multi_index(-MPerBlock, 0),
|
||||
make_multi_index(-MPerBlock, ScaleSliceSizeK));
|
||||
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
|
||||
|
||||
const index_t num_k_block_per_scale = ScaleBlockK / KPerBlock;
|
||||
constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
|
||||
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
|
||||
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
@@ -1420,6 +1401,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_slice_copy_step,
|
||||
|
||||
c_scale_thread_desc,
|
||||
c_thread_buf,
|
||||
|
||||
a_scale_grid_desc_am_ak,
|
||||
@@ -1434,8 +1417,7 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_slice_copy_step,
|
||||
|
||||
num_k_block_main_loop,
|
||||
num_k_block_per_scale);
|
||||
num_k_block_main_loop);
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
@@ -1446,23 +1428,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
// transposed XDL
|
||||
// // TODO: hacky, fix it!
|
||||
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
|
||||
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
|
||||
|
||||
// TODO: hacky, fix it!
|
||||
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
|
||||
// // TODO: hacky, fix it!
|
||||
// only used to get lengths
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
|
||||
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
|
||||
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
|
||||
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
|
||||
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
|
||||
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
|
||||
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
|
||||
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
|
||||
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
|
||||
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
|
||||
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
|
||||
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
|
||||
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
|
||||
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
|
||||
@@ -1471,24 +1454,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
static_cast<CShuffleDataType*>(p_shared),
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
|
||||
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_tuple(
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
|
||||
M1, // M1 = MWave
|
||||
M2, // M2 * M3 * M4 = MPerXdl
|
||||
M3,
|
||||
M4)),
|
||||
M2)), // M2 = MPerXdl
|
||||
make_freeze_transform(I0),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
|
||||
N1, // N1 = NWave
|
||||
N2))), // N2 = NPerXdl
|
||||
N2, // N2 * N3 * N4 = NPerXdl
|
||||
N3,
|
||||
N4))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(
|
||||
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
|
||||
Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
@@ -1498,57 +1481,57 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
|
||||
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
|
||||
|
||||
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
|
||||
const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
|
||||
make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto m_thread_data_on_block_idx =
|
||||
m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(m_thread_data_on_block));
|
||||
|
||||
const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto n_thread_data_on_block_idx =
|
||||
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
|
||||
n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(n_thread_data_on_block));
|
||||
|
||||
// shuffle: threadwise copy C from VGPR to LDS
|
||||
auto c_thread_copy_vgpr_to_lds =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
|
||||
CShuffleDataType,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
|
||||
decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
|
||||
tensor_operation::element_wise::PassThrough,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
I1,
|
||||
I1,
|
||||
M2,
|
||||
I1,
|
||||
M4,
|
||||
I1>,
|
||||
N2,
|
||||
I1,
|
||||
N4>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
7,
|
||||
1,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
m_thread_data_on_block_idx[I1],
|
||||
n_thread_data_on_block_idx[I1],
|
||||
m_thread_data_on_block_idx[I2],
|
||||
m_thread_data_on_block_idx[I3],
|
||||
m_thread_data_on_block_idx[I4],
|
||||
n_thread_data_on_block_idx[I2]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
n_thread_data_on_block_idx[I2],
|
||||
n_thread_data_on_block_idx[I3],
|
||||
n_thread_data_on_block_idx[I4]),
|
||||
tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
using EDataType = CDataType;
|
||||
|
||||
@@ -1630,18 +1613,17 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
|
||||
c_element_op};
|
||||
|
||||
// space filling curve for threadwise C in VGPR
|
||||
constexpr auto sfc_c_vgpr =
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
|
||||
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, 1, N2, 1, N4>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
Sequence<CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
1,
|
||||
1,
|
||||
M2,
|
||||
1,
|
||||
M4,
|
||||
1>>{};
|
||||
N2,
|
||||
1,
|
||||
N4>>{};
|
||||
|
||||
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
|
||||
|
||||
@@ -1661,10 +1643,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
block_sync_lds();
|
||||
|
||||
// each thread write its data from VGPR to LDS
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
|
||||
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
|
||||
c_thread_buf,
|
||||
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
|
||||
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
|
||||
c_shuffle_block_buf);
|
||||
|
||||
// make sure it's safe to read from LDS
|
||||
|
||||
@@ -17,7 +17,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -28,14 +28,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -46,14 +46,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -64,14 +64,14 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -82,61 +82,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpaddin
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
@@ -163,7 +109,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
B1DataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -180,7 +126,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
B1DataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
@@ -198,20 +144,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
|
||||
add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,6 +122,32 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kp
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instances_part3(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part3(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
@@ -360,6 +386,32 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpa
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part3(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part3(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
@@ -556,6 +608,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instances_part3(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part3(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
@@ -593,6 +649,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part3(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part3(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
|
||||
@@ -1,317 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if(defined(CK_ENABLE_F16) || defined(CK_ENABLE_FP8))
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
#endif
|
||||
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
#endif
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMultipleDSplitKBPreShuffle<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<Row, Col>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<F32, F32>,
|
||||
CDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::MultiplyMultiply>>
|
||||
{
|
||||
using DeviceOp =
|
||||
DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
BLayout,
|
||||
Tuple<Row, Col>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<F32, F32>,
|
||||
CDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::MultiplyMultiply>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
// TODO: Add MFMA layout into tensor layout
|
||||
#if(defined(CK_ENABLE_F16) || defined(CK_ENABLE_FP8))
|
||||
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
|
||||
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
|
||||
is_same_v<CDataType, bhalf_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,664 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if(defined(CK_ENABLE_F16) || defined(CK_ENABLE_FP8))
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p3(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p4(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p5(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p6(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
#endif
|
||||
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p1(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p1(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p3(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p4(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p5(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p6(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>&
|
||||
instances);
|
||||
|
||||
#endif
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMultipleDSplitKBPreShuffle<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<Row, Col>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<F32, F32>,
|
||||
CDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::MultiplyMultiply>>
|
||||
{
|
||||
using DeviceOp =
|
||||
DeviceGemmMultipleDSplitKBPreShuffle<ALayout,
|
||||
BLayout,
|
||||
Tuple<Row, Col>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Tuple<F32, F32>,
|
||||
CDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::MultiplyMultiply>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
// TODO: Add MFMA layout into tensor layout
|
||||
#if(defined(CK_ENABLE_F16) || defined(CK_ENABLE_FP8))
|
||||
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p3(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p4(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p5(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p6(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
|
||||
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
|
||||
is_same_v<CDataType, bhalf_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances_v2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances_v2(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p1(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p2(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p1(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p2(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p3(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p4(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p5(
|
||||
op_ptrs);
|
||||
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p6(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -77,7 +77,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
# Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
|
||||
if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply_xdl_f8")
|
||||
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply" AND source MATCHES "_f8_")
|
||||
message("removing gemm_multiply_multiply_f8 instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
@@ -88,18 +88,6 @@ function(add_instance_library INSTANCE_NAME)
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "batched_gemm_xdl_universal" AND source MATCHES "_f8_")
|
||||
message("removing batched_gemm_universal_f8 instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_xdl_universal_streamk" AND source MATCHES "_f8_")
|
||||
message("removing gemm_universal_streamk_f8 instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
#only continue if there are some source files left on the list
|
||||
if(ARGN)
|
||||
@@ -113,17 +101,21 @@ function(add_instance_library INSTANCE_NAME)
|
||||
elseif(source MATCHES "mha")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
#only build the fp8 gemm instances for gfx908/90a if the build argument is set
|
||||
#only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950
|
||||
if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
if(source MATCHES "gemm_xdl_universal" AND source MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
if(source MATCHES "gemm_multiply_multiply_f8")
|
||||
if(source MATCHES "gemm_multiply_multiply" AND source MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
if(source MATCHES "bached_gemm_multiply_multiply_f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
else()
|
||||
if(source MATCHES "gemm_xdl_universal" AND source MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
if(source MATCHES "gemm_multiply_multiply" AND source MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
endif()
|
||||
set(offload_targets)
|
||||
foreach(target IN LISTS INST_TARGETS)
|
||||
@@ -243,7 +235,7 @@ FOREACH(subdir_path ${dir_list})
|
||||
message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12"))
|
||||
if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12"))
|
||||
message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
@@ -251,14 +243,18 @@ FOREACH(subdir_path ${dir_list})
|
||||
message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12") AND (NOT INST_TARGETS MATCHES "gfx9"))
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12") AND (NOT INST_TARGETS MATCHES "gfx9"))
|
||||
message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12") AND (NOT INST_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
|
||||
if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12") AND (NOT INST_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS))
|
||||
message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94") AND (NOT INST_TARGETS MATCHES "gfx95") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
|
||||
message("Found gemm_multiply_multiply_f8 instances, but gfx94/gfx95 not on the target list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if((add_inst EQUAL 1))
|
||||
get_filename_component(target_dir ${subdir_path} NAME)
|
||||
add_subdirectory(${target_dir})
|
||||
|
||||
@@ -4,16 +4,13 @@ set(GEMM_AB_SCALE_INSTANCES)
|
||||
list(APPEND GEMM_AB_SCALE_INSTANCES
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp
|
||||
)
|
||||
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
add_instance_library(device_gemm_ab_scale_instance ${GEMM_AB_SCALE_INSTANCES})
|
||||
|
||||
@@ -34,49 +34,50 @@ static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances = std::tuple<
|
||||
using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Compute friendly
|
||||
// Spill in current compiler
|
||||
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
// DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances = std::tuple<
|
||||
using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData | BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Latency friendly
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
// Memory friendly
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 128, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 64, 32, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 64, 16, 128, 16, 16, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>
|
||||
// Memory friendly
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8,F32, F8,F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
|
||||
@@ -8,7 +8,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_i
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmDefault>{});
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances<GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -8,7 +8,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
@@ -28,7 +28,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmKPadding>{});
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances<GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,37 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmMNPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -8,7 +8,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
@@ -28,8 +28,8 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
|
||||
GemmDefault>{});
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances<Intrawave,
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -8,7 +8,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
@@ -19,7 +19,7 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpaddin
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
@@ -28,8 +28,8 @@ void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpaddin
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
|
||||
GemmKPadding>{});
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances<Intrawave,
|
||||
GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row,
|
||||
Col,
|
||||
Tuple<>,
|
||||
Row,
|
||||
F8,
|
||||
F32,
|
||||
F8,
|
||||
F32,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave,
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -10,6 +10,8 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part3.cpp
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part3.cpp
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
@@ -23,6 +25,8 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part3.cpp
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part3.cpp
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_default_instance.cpp
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
@@ -44,6 +48,8 @@ set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instance_part3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_kpadding_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
@@ -53,6 +59,8 @@ set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_g
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instance_part3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_f16/device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instance_part3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
@@ -36,12 +36,11 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part1 = std::tuple<
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
// Compute friendly
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -60,18 +59,16 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part1
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 32, 32, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part2 = std::tuple<
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
|
||||
// Compute friendly
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 32, 32, 3, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -100,19 +97,17 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part2
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 192, 256, 16, 16, 16, 16, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part1 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
// Compute friendly
|
||||
// 256x[64, 256, 32]x128
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -125,7 +120,20 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_insta
|
||||
// 224x[64, 256, 32]x128
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 224, 128, 16, 16, 16, 16, 7, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 7, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 7, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part2 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// Compute friendly
|
||||
// 224x[64, 256, 32]x128
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 160, 128, 16, 16, 16, 16, 7, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 96, 128, 16, 16, 16, 16, 7, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -139,19 +147,17 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_insta
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 96, 128, 16, 16, 16, 16, 6, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 256, 16, 16, 16, 16, 6, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part2 =
|
||||
using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part3 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
// Compute friendly
|
||||
|
||||
// 160x[64, 256, 32]x128, 160x[64, 96, 32]x256
|
||||
@@ -180,18 +186,16 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_insta
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 16, 16, 4, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple<
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
// Latency friendly
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
@@ -227,8 +231,7 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances = std:
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
|
||||
// clang-format on
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_default_instances_part3(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part3<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part3(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mfma16x16_instances_part3<
|
||||
GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -36,12 +36,11 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances_part1 = std::tuple<
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
// Compute friendly
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -59,18 +58,16 @@ using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances_part1
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 32, 32, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances_part2 = std::tuple<
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
|
||||
// Compute friendly
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 128, 16, 16, 32, 32, 1, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -92,19 +89,17 @@ using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_instances_part2
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 192, 256, 16, 16, 16, 16, 1, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instances_part1 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
// Compute friendly
|
||||
// 256x[64, 256, 32]x128
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -117,7 +112,20 @@ using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instan
|
||||
// 224x[64, 256, 32]x128
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 224, 128, 16, 16, 16, 16, 7, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 7, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 7, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instances_part2 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// Compute friendly
|
||||
// 224x[64, 256, 32]x128
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 160, 128, 16, 16, 16, 16, 7, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 96, 128, 16, 16, 16, 16, 7, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -131,19 +139,17 @@ using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instan
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 96, 128, 16, 16, 16, 16, 6, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 256, 16, 16, 16, 16, 6, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instances_part2 =
|
||||
using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instances_part3 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
// Compute friendly
|
||||
|
||||
// 160x[64, 256, 32]x128, 160x[64, 96, 32]x256
|
||||
@@ -167,18 +173,16 @@ using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instan
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 16, 16, 4, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_instances = std::tuple<
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
// Latency friendly
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
@@ -214,8 +218,7 @@ using device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_instances = std::
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
|
||||
// clang-format on
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_default_instances_part3(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instances_part3<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_kpadding_instances_part3(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_mfma16x16_instances_part3<
|
||||
GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,42 +0,0 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES)
|
||||
|
||||
list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance.cpp
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance.cpp
|
||||
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance.cpp
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance.cpp
|
||||
)
|
||||
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
add_instance_library(device_gemm_multiply_multiply_weight_preshuffle_instance ${GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES})
|
||||
@@ -0,0 +1,82 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES)
|
||||
|
||||
list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance_v2.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance_v2.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p1.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p2.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p1.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p2.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p3.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p4.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p5.cpp
|
||||
f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p6.cpp
|
||||
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance_v2.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance_v2.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p1.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p2.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p1.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p2.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p3.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p4.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p5.cpp
|
||||
f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp
|
||||
)
|
||||
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p4.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p5.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p6.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p4.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p5.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
add_instance_library(device_gemm_multiply_multiply_wp_instance ${GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES})
|
||||
@@ -1,14 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances(
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p1(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
@@ -23,7 +23,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances<
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_p1<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances(
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
@@ -23,7 +23,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma1
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_instances<
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_p2<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p3(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_p3<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p4(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_p4<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p5(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_p5<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p6(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_p6<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -40,107 +40,125 @@ static constexpr auto v2 = BlockGemmPipelineVersion::v2;
|
||||
template <BlockGemmPipelineVersion BlkGemmPipeVer, GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// clang-format off
|
||||
//##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// N 256
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// N 512
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 512, 128, 16, 16, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 512, 128, 16, 16, 32, 32, 1, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineVersion BlkGemmPipeVer, GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 32, 32, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// clang-format off
|
||||
//##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 512, 16, 16, 32, 32, 2, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 512, 16, 16, 32, 32, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// N 256
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 256, 16, 16, 32, 32, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 256, 16, 16, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 256, 16, 16, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 512, 16, 16, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 512, 16, 16, 32, 32, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// N 512
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 512, 256, 16, 16, 32, 32, 2, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 512, 256, 16, 16, 32, 32, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
|
||||
#endif
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 512, 16, 16, 32, 32, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineVersion BlkGemmPipeVer, GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// clang-format off
|
||||
//##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// N 256
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 256, 16, 16, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 256, 16, 16, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 512, 16, 16, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 512, 16, 16, 32, 32, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// N 512
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 512, 256, 16, 16, 32, 32, 2, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 512, 256, 16, 16, 32, 32, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineVersion BlkGemmPipeVer, GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 512, 16, 16, 16, 16, 1, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
|
||||
#endif
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 512, 16, 16, 16, 16, 1, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineVersion BlkGemmPipeVer, GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_instances =
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_instances_p1 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
// clang-format off
|
||||
//##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 32, 32, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 32, 32, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 32, 32, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_instances_p2 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 32, 32, 7, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 32, 32, 6, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances =
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_p1 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
//############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
// Compute friendly
|
||||
// 256x[64, 256, 32]x128
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -149,7 +167,18 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 160, 128, 16, 16, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 96, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_p2 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// 224x[64, 256, 32]x128
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 224, 128, 16, 16, 16, 16, 7, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -157,7 +186,17 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 160, 128, 16, 16, 16, 16, 7, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 96, 128, 16, 16, 16, 16, 7, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_p3 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// 192x[64, 256, 32]x128, 192x[64]x256
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 6, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 224, 128, 16, 16, 16, 16, 6, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -165,7 +204,17 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 160, 128, 16, 16, 16, 16, 6, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 96, 128, 16, 16, 16, 16, 6, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_p4 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// 160x[64, 256, 32]x128, 160x[64, 96, 32]x256
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 5, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 224, 128, 16, 16, 16, 16, 5, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -173,19 +222,38 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 160, 128, 16, 16, 16, 16, 5, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 5, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 96, 128, 16, 16, 16, 16, 5, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
// 128x[64, 256, 32]x128, 128x[64, 128, 32]x256
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 4, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 128, 16, 16, 16, 16, 4, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_p5 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 128, 16, 16, 16, 16, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 16, 16, 4, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_p6 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 4, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 128, 16, 16, 16, 16, 4, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p1(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_instances_p1<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_instances_p2<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_instances<
|
||||
v1,
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_instances<
|
||||
v2,
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_instances<
|
||||
v1,
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_instances<
|
||||
v2,
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,14 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances(
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
@@ -23,7 +23,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances<
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p1<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p2<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p3(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p3<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p4(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p4<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p5(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p5<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p6(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p6<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -40,107 +40,125 @@ static constexpr auto v2 = BlockGemmPipelineVersion::v2;
|
||||
template <BlockGemmPipelineVersion BlkGemmPipeVer, GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// clang-format off
|
||||
//##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// N 256
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// N 512
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 512, 128, 16, 16, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 512, 128, 16, 16, 32, 32, 1, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineVersion BlkGemmPipeVer, GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 32, 32, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// clang-format off
|
||||
//##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 512, 16, 16, 32, 32, 2, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 512, 16, 16, 32, 32, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// N 256
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 256, 16, 16, 32, 32, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 256, 16, 16, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 256, 16, 16, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 512, 16, 16, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 512, 16, 16, 32, 32, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// N 512
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 512, 256, 16, 16, 32, 32, 2, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 512, 256, 16, 16, 32, 32, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
|
||||
#endif
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 128, 512, 16, 16, 32, 32, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineVersion BlkGemmPipeVer, GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// clang-format off
|
||||
//##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// N 256
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 256, 16, 16, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 256, 16, 16, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 256, 512, 16, 16, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 512, 16, 16, 32, 32, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
// N 512
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 512, 256, 16, 16, 32, 32, 2, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 512, 256, 16, 16, 32, 32, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineVersion BlkGemmPipeVer, GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 512, 16, 16, 16, 16, 1, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
|
||||
#endif
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 512, 16, 16, 16, 16, 1, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineVersion BlkGemmPipeVer, GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_instances =
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_instances_p1 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
// clang-format off
|
||||
//##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 32, 32, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 32, 32, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 32, 32, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_instances_p2 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 32, 32, 7, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 32, 32, 6, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances =
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p1 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
//############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
// Compute friendly
|
||||
// 256x[64, 256, 32]x128
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -149,7 +167,18 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 160, 128, 16, 16, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 96, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p2 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// 224x[64, 256, 32]x128
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 224, 128, 16, 16, 16, 16, 7, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -157,7 +186,17 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 160, 128, 16, 16, 16, 16, 7, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 96, 128, 16, 16, 16, 16, 7, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p3 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// 192x[64, 256, 32]x128, 192x[64]x256
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 6, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 224, 128, 16, 16, 16, 16, 6, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -165,7 +204,17 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 160, 128, 16, 16, 16, 16, 6, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 96, 128, 16, 16, 16, 16, 6, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p4 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// 160x[64, 256, 32]x128, 160x[64, 96, 32]x256
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 5, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 224, 128, 16, 16, 16, 16, 5, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
@@ -173,19 +222,38 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 160, 128, 16, 16, 16, 16, 5, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 5, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 96, 128, 16, 16, 16, 16, 5, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
// 128x[64, 256, 32]x128, 128x[64, 128, 32]x256
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 4, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 128, 16, 16, 16, 16, 4, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p5 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 128, 16, 16, 16, 16, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 16, 16, 4, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
#endif
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p6 =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 4, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 128, 16, 16, 16, 16, 4, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
|
||||
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances(
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
@@ -23,7 +23,7 @@ void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_m
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_instances<
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_instances_p1<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_instances_p2<
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -1,7 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_instances<
|
||||
v1,
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_instances<
|
||||
v2,
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_instances<
|
||||
v1,
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitKBPreShuffle<Row,
|
||||
Col,
|
||||
Tuple<Row, Col>,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
Tuple<F32, F32>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_instances<
|
||||
v2,
|
||||
GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_wp.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
@@ -50,7 +50,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_weight_preshuffle.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_wp.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)
|
||||
@@ -140,7 +140,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_weight_preshuffle_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_wp_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
|
||||
endif()
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
|
||||
|
||||
@@ -32,6 +32,7 @@ enum struct GemmDataType
|
||||
enum struct ScaleBlockTile
|
||||
{
|
||||
Tile_128_128_128, // 0
|
||||
Tile_1_128_128, // 1
|
||||
};
|
||||
|
||||
#define OP_NAME "gemm_ab_scale"
|
||||
@@ -49,7 +50,8 @@ int profile_gemm_ab_scale(int argc, char* argv[])
|
||||
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
|
||||
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
|
||||
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
|
||||
printf("arg4: scale block tile (0: ScaleBlockM/N/K = [128, 128, 128];\n");
|
||||
printf("arg4: scale block tile (0: ScaleBlockM/N/K = [128, 128, 128]; 1: ScaleBlockM/N/K = "
|
||||
"[1, 128, 128];\n");
|
||||
printf("arg5: verification (0: no; 1: yes)\n");
|
||||
printf("arg6: initialization (0: no init; 1: integer value; 2: decimal value)\n");
|
||||
printf("arg7: print tensor value (0: no; 1: yes)\n");
|
||||
@@ -155,7 +157,7 @@ int profile_gemm_ab_scale(int argc, char* argv[])
|
||||
};
|
||||
|
||||
if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN &&
|
||||
scale_block_tile == ScaleBlockTile::Tile_128_128_128)
|
||||
scale_block_tile == ScaleBlockTile::Tile_1_128_128)
|
||||
{
|
||||
return profile(F8{},
|
||||
F32{},
|
||||
@@ -164,7 +166,7 @@ int profile_gemm_ab_scale(int argc, char* argv[])
|
||||
F8{},
|
||||
F32{},
|
||||
BF16{},
|
||||
ck::Number<128>{},
|
||||
ck::Number<1>{},
|
||||
ck::Number<128>{},
|
||||
ck::Number<128>{},
|
||||
Row{},
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "profiler/profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp"
|
||||
#include "profiler/profile_gemm_multiply_multiply_wp_impl.hpp"
|
||||
#include "profiler_operation_registry.hpp"
|
||||
|
||||
enum struct GemmMatrixLayout
|
||||
Reference in New Issue
Block a user