mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[GEMM] Gemm universal device operation (#1154)
* Optimize GEMM on MI200/300:
1. Add new blockwise gemm pipeline
2. Add irregular splitk intances
* clang format + typo fix
* Fix a bug
* initial commit
* Add more instances to irregular splitk
* blkgemm pipeline v1~4 prototype
* Sanity Checked. Known issue:
1. Poor performance of splitk
2. Register spill on blkgemmpipeline v3
* Sanity and Performance fix:
1. fix a bug related to sanity in grouped b2c mapping
2. fix a bug related to sanity and performance in splitk offset
* Sanity and API update:
1. Remove prefetch stage
2. Fix valid check bug
3, Add first gemm_universal instance into ckProfiler
* Add NN instances for gemm universal
* 1. Add NT instances for gemm_universal
2. Fix a bug about Kpadding in gemm_universal
* Fix a bug regarding padding Odd K number
* remove kernel print
* Fix KPadding bug...
* Update safety check
* another try to fix kpadding..
* Sanity checked
* new instances..
* clang format+typo fix
* remove clang format script's change
* Add non-hotloop compile option
* 1. Add fp16xfp8 example
2. pull packed convert f8 from pr1150
* Some miscs.. opt and fix
* Add pipeline description docs
* Split universal gemm instance library to cut profiler compiling time
* uncomment cmakefile
* Fix a bug caused by blockwise_gemm_pipe_v2
* reduce default splitk to 1
* Add 224x256x64 tile size
* update, including:
1. Experiment pipeline 5~7
2. Optimization for pipeline 4
3. Organized instance library
* temp save
* temp save
* Permuted lds layout, sanity and function checked
* clang format
* Move OOB check from RunRead to RunWrite, for better software pipeline.
TODO: agpr spill when NN layout
* clangformat
* A/B splitpipe scheduler for v3
* Fix two bugs
* bug fix
* fix a bug in oob check
* Example for mixed fp16_fp8 gemm
* Clean experimental code blocks
* Add mixed precision gemm into profiler
* tempsave
* optimize m/n major lds layout
* Add RRR GEMM mixed precision instances
* Optimize f8 matrix transpose
* Add test_gemm_universal
* A/B spilt schedule for blkpip v5
* Take ds_read2 into iglp scheduling scheme
* format
* fixed cmake
* Add llvm-option into CI cmake flag
---------
Co-authored-by: Jing Zhang <jizhan@amd.com>
[ROCm/composable_kernel commit: f83e9701e9]
This commit is contained in:
14
Jenkinsfile
vendored
14
Jenkinsfile
vendored
@@ -824,7 +824,7 @@ pipeline {
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx908;gfx90a" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j check"""
|
||||
-DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j check"""
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -848,12 +848,12 @@ pipeline {
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \
|
||||
-DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \
|
||||
-DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert " \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
-DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
-DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
@@ -868,12 +868,12 @@ pipeline {
|
||||
}
|
||||
agent{ label rocmnode("gfx942") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx942" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
-DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
@@ -888,12 +888,12 @@ pipeline {
|
||||
}
|
||||
agent{ label rocmnode("gfx908 || gfx90a") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx908;gfx90a" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
-DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
|
||||
@@ -22,6 +22,13 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
|
||||
add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2)
|
||||
|
||||
add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3)
|
||||
add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
|
||||
add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)
|
||||
|
||||
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
|
||||
|
||||
|
||||
@@ -46,6 +46,19 @@ struct ProblemSizeStreamK final
|
||||
ck::index_t NumSKBlocks = -1;
|
||||
};
|
||||
|
||||
struct ProblemSizeSplitK final
|
||||
{
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -158,3 +171,52 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool parse_cmd_args<ProblemSizeSplitK>(int argc,
|
||||
char* argv[],
|
||||
ProblemSizeSplitK& problem_size,
|
||||
ExecutionConfig& config)
|
||||
{
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc >= 10)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
|
||||
problem_size.StrideA = std::stoi(argv[7]);
|
||||
problem_size.StrideB = std::stoi(argv[8]);
|
||||
problem_size.StrideC = std::stoi(argv[9]);
|
||||
|
||||
if(argc >= 11)
|
||||
{
|
||||
problem_size.KBatch = std::stoi(argv[10]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg10: KBatch" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
53
example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp
Normal file
53
example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp
Normal file
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
64,
|
||||
16, 16,
|
||||
64, 16, 8,
|
||||
16, 16,
|
||||
1, 1,
|
||||
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 0,
|
||||
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
1, 1, S<1, 16, 1, 4>, 4,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v1>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
#include "run_gemm_example_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
48
example/01_gemm/gemm_xdl_fp16_v3.cpp
Normal file
48
example/01_gemm/gemm_xdl_fp16_v3.cpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
256,
|
||||
224, 256,
|
||||
64, 8, 2,
|
||||
16, 16,
|
||||
7, 8,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 8, 2, 0,
|
||||
1, 2, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
48
example/01_gemm/gemm_xdl_fp8_v3.cpp
Normal file
48
example/01_gemm/gemm_xdl_fp8_v3.cpp
Normal file
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
256,
|
||||
128, 256,
|
||||
128, 16, 16,
|
||||
16, 16,
|
||||
4, 8,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 16, 16, 1,
|
||||
1, 2, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ck::f8_t>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
211
example/01_gemm/run_gemm_example_v2.inc
Normal file
211
example/01_gemm/run_gemm_example_v2.inc
Normal file
@@ -0,0 +1,211 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
|
||||
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
|
||||
#endif
|
||||
|
||||
using namespace ck::literals;
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
auto K = problem_size.K;
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
auto KBatch = problem_size.KBatch;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
// give a chance if stride is zero, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
break;
|
||||
case 3:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
}
|
||||
#if 0
|
||||
printf("B matrix:\n");
|
||||
for (int in = 0; in < N; in++)
|
||||
{
|
||||
for (int ik = 0; ik < K; ik++)
|
||||
{
|
||||
printf("%02x ", *(reinterpret_cast<uint8_t*>(&b_k_n(ik,in))));
|
||||
if(ik%8==7) printf("|");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) *
|
||||
c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
const Tensor<KernelADataType> a_m_k_converted(a_m_k);
|
||||
const Tensor<KernelBDataType> b_k_n_converted(b_k_n);
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data());
|
||||
#else
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
#endif
|
||||
DeviceMem workspace;
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmV2Instance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
float ave_time = 0;
|
||||
|
||||
auto argument = gemm.MakeArgument(
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#else
|
||||
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#endif
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1});
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
Tensor<CDataType> c_m_n_device_result_converted(c_m_n_host_result.mDesc);
|
||||
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data());
|
||||
|
||||
c_m_n_device_result = c_m_n_device_result_converted.CopyAsType<CDataType>();
|
||||
|
||||
return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result);
|
||||
#else
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
|
||||
#endif
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t flop = 2_uz * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_gemm_splitk_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSizeSplitK problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
|
||||
}
|
||||
@@ -1951,4 +1951,89 @@ struct Modulo
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename LowLengths>
|
||||
struct Xor
|
||||
{
|
||||
using LowerIndex = MultiIndex<2>;
|
||||
using UpperIndex = MultiIndex<2>;
|
||||
|
||||
using UpLengths = LowLengths;
|
||||
|
||||
UpLengths up_lengths_;
|
||||
|
||||
__host__ __device__ constexpr Xor() : up_lengths_{} {}
|
||||
|
||||
__host__ __device__ constexpr Xor(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 2; }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 2; }
|
||||
|
||||
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
static_assert(LowIdx::Size() == 2 && UpIdx::Size() == 2,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_low(Number<0>{}) = idx_up[Number<0>{}];
|
||||
|
||||
idx_low(Number<1>{}) =
|
||||
idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]);
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff,
|
||||
typename UpIdxDiff,
|
||||
typename LowIdx,
|
||||
typename UpIdx,
|
||||
index_t Hack>
|
||||
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff&,
|
||||
LowIdx& idx_low,
|
||||
const UpIdx& idx_up,
|
||||
Number<Hack>) const
|
||||
{
|
||||
static_assert(LowIdxDiff::Size() == 2 && UpIdxDiff::Size() == 2 && LowIdx::Size() == 2 &&
|
||||
UpIdx::Size() == 2,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
const auto idx_low_old = idx_low;
|
||||
|
||||
CalculateLowerIndex(idx_low, idx_up);
|
||||
|
||||
idx_diff_low = idx_low - idx_low_old;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename UpIdx>
|
||||
__host__ __device__ static constexpr bool
|
||||
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
|
||||
{
|
||||
return is_known_at_compile_time<UpLengths>::value;
|
||||
}
|
||||
|
||||
__host__ __device__ void Print() const
|
||||
{
|
||||
printf("Xor{");
|
||||
|
||||
//
|
||||
printf("up_lengths_: ");
|
||||
print(up_lengths_);
|
||||
printf(", ");
|
||||
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
} // namespace ck
|
||||
|
||||
@@ -127,4 +127,10 @@ __host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
|
||||
{
|
||||
return Modulo<Modulus, UpLength>{modulus, up_length};
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
__host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths)
|
||||
{
|
||||
return Xor<LowLengths>{low_lengths};
|
||||
}
|
||||
} // namespace ck
|
||||
|
||||
@@ -960,13 +960,13 @@ struct BlockwiseGemmXdlops_pipeline_v4
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
|
||||
make_tuple(
|
||||
Number<KPack>{}, Number<KPack * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
|
||||
Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
|
||||
|
||||
// B[N0, N1, N2, KPack]
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
|
||||
make_tuple(
|
||||
Number<KPack>{}, Number<KPack * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
|
||||
Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
|
||||
|
||||
// C[M, N, NumRegXdlops]
|
||||
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
|
||||
@@ -0,0 +1,354 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
|
||||
#include "ck/tensor_description/tensor_adaptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool TransposeC = false>
|
||||
struct BlockwiseGemmXdlops_pipeline_base
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
// Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs.
|
||||
static constexpr index_t WaveSize = 64;
|
||||
|
||||
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
|
||||
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
|
||||
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
|
||||
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr auto xdlops_gemm =
|
||||
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType, TransposeC>{};
|
||||
|
||||
static constexpr index_t AMmaKStride = KPack;
|
||||
static constexpr index_t BMmaKStride = KPack;
|
||||
|
||||
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
|
||||
static constexpr index_t KRepeat = KPerThread / KPack;
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
using HotLoopInstList =
|
||||
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
A_K1,
|
||||
B_K1,
|
||||
A_K1,
|
||||
B_K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
xdlops_gemm.KPerXdlops>;
|
||||
|
||||
static_assert(KPerThread % KPack == 0,
|
||||
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
|
||||
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
MRepeat * NRepeat,
|
||||
xdlops_gemm.GetRegSizePerXdlops(),
|
||||
true>
|
||||
c_thread_buf_;
|
||||
|
||||
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
|
||||
|
||||
__device__ static auto GetWaveIdx()
|
||||
{
|
||||
const index_t thread_id = ThisThreadBlock::GetThreadId();
|
||||
|
||||
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
}
|
||||
|
||||
__device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
|
||||
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
|
||||
|
||||
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}));
|
||||
|
||||
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
|
||||
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
|
||||
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
|
||||
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
|
||||
|
||||
return make_tuple(c_thread_m, c_thread_n);
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
|
||||
|
||||
return make_tuple(
|
||||
m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
|
||||
}
|
||||
|
||||
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
|
||||
|
||||
__host__ __device__
|
||||
BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
|
||||
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
|
||||
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
|
||||
{
|
||||
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
|
||||
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
|
||||
|
||||
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
|
||||
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
|
||||
{
|
||||
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
|
||||
|
||||
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
|
||||
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
|
||||
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
|
||||
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
|
||||
}
|
||||
|
||||
// XDL output supporting C_xdl = A_xdl * B_xdl
|
||||
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
|
||||
{
|
||||
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
|
||||
|
||||
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
|
||||
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
|
||||
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
|
||||
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
|
||||
{
|
||||
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
|
||||
|
||||
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
|
||||
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
|
||||
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
|
||||
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
|
||||
}
|
||||
|
||||
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
|
||||
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
|
||||
{
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
||||
Number<NRepeat>{},
|
||||
Number<MWaves>{},
|
||||
Number<NWaves>{},
|
||||
Number<MPerXDL>{},
|
||||
Number<NPerXDL>{}));
|
||||
|
||||
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
|
||||
}
|
||||
|
||||
// XDL output supporting C_xdl = A_xdl * B_xdl
|
||||
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
|
||||
{
|
||||
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
||||
Number<NRepeat>{},
|
||||
Number<MWaves>{},
|
||||
Number<NWaves>{},
|
||||
Number<MPerXDL>{},
|
||||
Number<NPerXDL>{}));
|
||||
|
||||
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
|
||||
{
|
||||
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1,
|
||||
Number<MRepeat>{},
|
||||
Number<NRepeat>{},
|
||||
Number<MWaves>{},
|
||||
Number<NWaves>{},
|
||||
Number<MPerXDL>{},
|
||||
Number<NPerXDL>{}));
|
||||
|
||||
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
|
||||
c_block_desc_g_m0_n0_m1_n1_m2_n2);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
|
||||
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
|
||||
|
||||
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
|
||||
}
|
||||
|
||||
template <typename CGridDesc_G_M_N>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
|
||||
{
|
||||
const auto G = c_grid_desc_g_m_n.GetLength(I0);
|
||||
const auto M = c_grid_desc_g_m_n.GetLength(I1);
|
||||
const auto N = c_grid_desc_g_m_n.GetLength(I2);
|
||||
|
||||
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
|
||||
c_grid_desc_g_m_n,
|
||||
make_tuple(make_pass_through_transform(G),
|
||||
make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
|
||||
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
|
||||
|
||||
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
|
||||
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
|
||||
}
|
||||
|
||||
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
|
||||
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
|
||||
|
||||
protected:
|
||||
// M1, N1 as double buffer index
|
||||
// Read buffer + Compute buffer
|
||||
// A[M0, M1, M2, KPack]
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
|
||||
make_tuple(
|
||||
Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
|
||||
|
||||
// B[N0, N1, N2, KPack]
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
|
||||
make_tuple(
|
||||
Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
|
||||
|
||||
// C[M, N, NumRegXdlops]
|
||||
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
ComputeDataType,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
BThreadCopy b_thread_copy_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,167 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum struct BlockGemmPipelineVersion
|
||||
{
|
||||
v1, // Naive
|
||||
v2, // Mem
|
||||
v3, // Comp
|
||||
v4, // Comp, double lds buffer
|
||||
v5, // Comp, double global prefetch register buffer
|
||||
};
|
||||
|
||||
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSche,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
constexpr auto BlockGemmPipeline_Selector()
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_v1<BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_v2<BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_v3<BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_v4<BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_v5<BlkGemmPipeSche,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,732 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Naive pipeline with lowest resource request per WGP
|
||||
// GlobalPrefetchStages: 1
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 0
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks>
|
||||
struct BlockwiseGemmXdlops_pipeline_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
>
|
||||
struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::I0;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
ignore = num_loop;
|
||||
return TailNumber::Full;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
// -------------------------------------------------------------------------------------------
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, I0),
|
||||
a_thread_buf);
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, I0),
|
||||
a_thread_buf);
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
using Base::b_thread_copy_;
|
||||
using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
>
|
||||
struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::A_K1;
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KPerThread;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
|
||||
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
ignore = num_loop;
|
||||
return TailNumber::Full;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
// -------------------------------------------------------------------------------------------
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k0, I0),
|
||||
a_thread_buf);
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
|
||||
// but except the first, as we can shorten non-MAC cluster a bit and there's no
|
||||
// observable negative impact. The desired effect is waves in a workgroup
|
||||
// executing MAC in sync. This avoids some out-of-sync waves hijacking MAC
|
||||
// resource from other workgroups and reducing the chance of latency hiding by
|
||||
// waiting for the rest of the workgroup at the eventual sync point.
|
||||
if constexpr(k0.value != 0 || KRepeat == 1)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, k_ + ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, k_ + ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// The block_sync_lds() here performs double duty:
|
||||
// A) safeguard against data hazard because barrier from
|
||||
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion by
|
||||
// applying small delays to different wavefronts It is performed
|
||||
// near the end of MAC cluster to minimize lgkmcnt penalty
|
||||
if constexpr(k0.value == KRepeat - 1 &&
|
||||
k_.value == KPerInnerLoop - KPack &&
|
||||
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
block_sync_lds();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(1);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
|
||||
// block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k0, I0),
|
||||
a_thread_buf);
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr(k0.value != 0 || KRepeat == 1)
|
||||
{
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, k_ + ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, k_ + ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
if constexpr(k0.value == KRepeat - 1 &&
|
||||
k_.value == KPerInnerLoop - KPack &&
|
||||
m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
block_sync_lds();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(1);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// K->M loopover
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
|
||||
make_tuple(Number<KPerInnerLoop>{},
|
||||
Number<KRepeat * MRepeat * KPerInnerLoop>{},
|
||||
Number<MRepeat * KPerInnerLoop>{},
|
||||
I1));
|
||||
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
|
||||
make_tuple(Number<KPerInnerLoop>{},
|
||||
Number<KRepeat * NRepeat * KPerInnerLoop>{},
|
||||
Number<NRepeat * KPerInnerLoop>{},
|
||||
I1));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
ComputeDataType,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPerInnerLoop>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
|
||||
BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
|
||||
using Base::c_thread_desc_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,439 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks>
|
||||
struct BlockwiseGemmXdlops_pipeline_v3
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
>
|
||||
struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
ignore = num_loop;
|
||||
return TailNumber::Full;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
// A/B split schedule
|
||||
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
|
||||
? HotLoopInstList::A_LDS_Read_Inst_Num
|
||||
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
|
||||
? HotLoopInstList::B_LDS_Read_Inst_Num
|
||||
: HotLoopInstList::B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
|
||||
constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
|
||||
|
||||
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 8 + ds_read_a_issue_cycle - 1) / ds_read_a_issue_cycle;
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 8 + ds_read_b_issue_cycle - 1) / ds_read_b_issue_cycle;
|
||||
|
||||
// stage 1
|
||||
// Separate this part?
|
||||
constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
|
||||
sizeof(ComputeDataType) / sizeof(BDataType)
|
||||
? sizeof(ComputeDataType) / sizeof(ADataType)
|
||||
: sizeof(ComputeDataType) / sizeof(BDataType);
|
||||
constexpr auto num_mfma_stage1 =
|
||||
num_mfma_inst - num_mfma_per_ds_read * (num_ds_read_inst_a / ds_read_a_mfma_rate +
|
||||
num_ds_read_inst_b / ds_read_b_mfma_rate);
|
||||
constexpr auto num_mfma_per_issue =
|
||||
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
|
||||
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
|
||||
});
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_ds_read_inst_a / ds_read_a_mfma_rate, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_ds_read, 0); // MFMA
|
||||
});
|
||||
|
||||
static_for<0, num_ds_read_inst_b / ds_read_b_mfma_rate, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_ds_read, 0); // MFMA
|
||||
});
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// Local prefetch 1
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
using Base::b_thread_copy_;
|
||||
using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,597 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Compute optimimal pipeline with highest resource request
|
||||
// GlobalPrefetchStages: 4
|
||||
// LocalPreFillStages: 2
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 2
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks>
|
||||
struct BlockwiseGemmXdlops_pipeline_v4
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
>
|
||||
struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
static constexpr index_t PrefetchStages = 4;
|
||||
static constexpr index_t PrefillStages = 2;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
static constexpr index_t HotloopUnroll = 2;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop % HotloopUnroll == 1)
|
||||
{
|
||||
return TailNumber::Odd;
|
||||
}
|
||||
else
|
||||
{
|
||||
return TailNumber::Even;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ScheduleGroup>
|
||||
__device__ static constexpr void HotLoopScheduler(ScheduleGroup schedule_group)
|
||||
{
|
||||
// TODO: Take data type into consideration as pipe ver 3
|
||||
// A-B splited schedule
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
|
||||
? HotLoopInstList::A_LDS_Read_Inst_Num
|
||||
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
|
||||
? HotLoopInstList::B_LDS_Read_Inst_Num
|
||||
: HotLoopInstList::B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_dswrite_per_issue_a =
|
||||
(HotLoopInstList::A_LDS_Write_Inst_Num + num_issue_a - 1) / num_issue_a;
|
||||
constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a;
|
||||
|
||||
constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_dswrite_per_issue_b =
|
||||
(HotLoopInstList::B_LDS_Write_Inst_Num + num_issue_b - 1) / num_issue_b;
|
||||
constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b;
|
||||
|
||||
constexpr auto num_mfma_per_issue =
|
||||
HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b);
|
||||
|
||||
static_for<0, num_issue_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) {
|
||||
ignore = idsread;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA
|
||||
});
|
||||
|
||||
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008,
|
||||
num_mfma_per_issue - num_dsread_per_issue_a -
|
||||
num_dswrite_per_issue_a,
|
||||
schedule_group); // MFMA
|
||||
});
|
||||
|
||||
static_for<0, num_issue_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) {
|
||||
ignore = idsread;
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA
|
||||
});
|
||||
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008,
|
||||
num_mfma_per_issue - num_dsread_per_issue_a -
|
||||
num_dswrite_per_issue_b,
|
||||
schedule_group); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0), I0);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0), I0);
|
||||
|
||||
// Local prefill 2
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1), I1);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1), I1);
|
||||
|
||||
// Local prefetch 1
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
|
||||
a_block_buf.At(I0),
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, I0),
|
||||
a_thread_bufs(I0));
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
|
||||
b_block_buf.At(I0),
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, I0),
|
||||
b_thread_bufs(I0));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Global prefetch 3
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Global prefetch 4
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
// This hot loop has two legacy loopover, to implement the double local buffer strategy
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto lds_read_buf,
|
||||
auto lds_read_reg_buf,
|
||||
auto lds_write_buf,
|
||||
auto vmem_buf,
|
||||
auto mfma_reg_buf,
|
||||
auto schedule_group) {
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
|
||||
a_block_buf.At(lds_read_buf),
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, I0),
|
||||
a_thread_bufs(lds_read_reg_buf));
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
|
||||
b_block_buf.At(lds_read_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, I0),
|
||||
b_thread_bufs(lds_read_reg_buf));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
a_blockwise_copy.RunWrite(
|
||||
a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
|
||||
b_blockwise_copy.RunWrite(
|
||||
b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_bufs[mfma_reg_buf]
|
||||
[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler(schedule_group);
|
||||
};
|
||||
|
||||
LoopFunc(I1, I1, I0, I0, I0, I0);
|
||||
LoopFunc(I0, I0, I1, I1, I1, I0);
|
||||
|
||||
i += HotloopUnroll;
|
||||
} while(i < (num_loop - PrefetchStages));
|
||||
}
|
||||
|
||||
auto ReadWriteCompFunc = [&](auto lds_read_buf,
|
||||
auto lds_read_reg_buf,
|
||||
auto lds_write_buf,
|
||||
auto vmem_buf,
|
||||
auto mfma_reg_buf,
|
||||
auto schedule_group) {
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
|
||||
a_block_buf.At(lds_read_buf),
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, I0),
|
||||
a_thread_bufs(lds_read_reg_buf));
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
|
||||
b_block_buf.At(lds_read_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, I0),
|
||||
b_thread_bufs(lds_read_reg_buf));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler(schedule_group);
|
||||
};
|
||||
|
||||
auto ReadCompFunc = [&](auto lds_read_buf,
|
||||
auto lds_read_reg_buf,
|
||||
auto mfma_reg_buf,
|
||||
auto schedule_group) {
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
|
||||
a_block_buf.At(lds_read_buf),
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, I0),
|
||||
a_thread_bufs(lds_read_reg_buf));
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
|
||||
b_block_buf.At(lds_read_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, I0),
|
||||
b_thread_bufs(lds_read_reg_buf));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler(schedule_group);
|
||||
};
|
||||
|
||||
auto CompFunc = [&](auto mfma_reg_buf) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
ReadWriteCompFunc(I1, I1, I0, I0, I0, I1);
|
||||
ReadCompFunc(I0, I0, I1, I1);
|
||||
CompFunc(I0);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
ReadWriteCompFunc(I1, I1, I0, I0, I0, I1);
|
||||
ReadWriteCompFunc(I0, I0, I1, I1, I1, I1);
|
||||
ReadCompFunc(I1, I1, I0, I1);
|
||||
CompFunc(I1);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
using Base::b_thread_copy_;
|
||||
using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,667 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 3
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 2
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks>
|
||||
struct BlockwiseGemmXdlops_pipeline_v5
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
>
|
||||
struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::A_K1;
|
||||
using Base::B_K1;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
static constexpr index_t PrefetchStages = 3;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
static constexpr index_t HotloopUnroll = 2;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop % HotloopUnroll == 1)
|
||||
{
|
||||
return TailNumber::Odd;
|
||||
}
|
||||
else
|
||||
{
|
||||
return TailNumber::Even;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
// TODO: Take data type into consideration as pipe ver 3
|
||||
// A/B split schedule
|
||||
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
|
||||
? HotLoopInstList::A_LDS_Read_Inst_Num
|
||||
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
|
||||
? HotLoopInstList::B_LDS_Read_Inst_Num
|
||||
: HotLoopInstList::B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
|
||||
constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
|
||||
|
||||
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 8 + ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 8 + ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
|
||||
constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
|
||||
constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat;
|
||||
constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat;
|
||||
|
||||
constexpr auto num_dsread_stage1_a_mfma =
|
||||
(num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
constexpr auto num_dsread_stage1_b_mfma =
|
||||
(num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
constexpr auto num_dsread_stage3_a_mfma =
|
||||
(num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
constexpr auto num_dsread_stage3_b_mfma =
|
||||
(num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
constexpr auto num_mfma_stage2 = num_mfma_inst - num_ds_read_inst_a / ds_read_a_mfma_rate -
|
||||
num_ds_read_inst_b / ds_read_b_mfma_rate;
|
||||
constexpr auto num_mfma_per_issue =
|
||||
num_mfma_stage2 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
|
||||
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
|
||||
// stage 1
|
||||
static_for<0, num_dsread_stage1_a_mfma, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, num_dsread_stage1_b_mfma, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
|
||||
});
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
|
||||
});
|
||||
|
||||
// stage 3
|
||||
static_for<0, num_dsread_stage3_a_mfma, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, num_dsread_stage3_b_mfma, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
|
||||
// IGLP COMPILER BUG:
|
||||
// If comment out following scheduler barrier would cause sanity fail.
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
|
||||
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Global prefetch 3
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// Local prefetch 1
|
||||
block_sync_lds();
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto vmem_buf) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
if constexpr(k0 == (KRepeat - 1))
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, ik))>{}];
|
||||
});
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, I0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
};
|
||||
|
||||
LoopFunc(I0);
|
||||
LoopFunc(I1);
|
||||
|
||||
i += HotloopUnroll;
|
||||
} while(i < (num_loop - PrefetchStages));
|
||||
}
|
||||
// tail
|
||||
auto ReadWriteCompFunc = [&](auto vmem_buf) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
if constexpr(k0 == (KRepeat - 1))
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, ik))>{}];
|
||||
});
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, I0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
};
|
||||
auto ReadCompFunc = [&]() {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KRepeat - 1, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, I0, ik))>{}];
|
||||
});
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, I0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_buf
|
||||
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, I0, ik))>{}];
|
||||
});
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_buf
|
||||
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, I0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
};
|
||||
|
||||
if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
ReadWriteCompFunc(I0);
|
||||
ReadWriteCompFunc(I1);
|
||||
ReadCompFunc();
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
ReadWriteCompFunc(I0);
|
||||
ReadCompFunc();
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// A[MRepeat, I1, I1, KPack]
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, I1, I1, Number<KPack>{}));
|
||||
|
||||
// B[NRepeat, N1, N2, KPack]
|
||||
static constexpr auto b_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, I1, I1, Number<KPack>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeDataType,
|
||||
decltype(a_block_desc_m0_m1_m2_k),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
ComputeDataType,
|
||||
decltype(b_block_desc_n0_n1_n2_k),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
|
||||
BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
|
||||
using Base::c_thread_desc_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
43
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
Normal file
43
include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp
Normal file
@@ -0,0 +1,43 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemmV2 : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t KSplit,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,687 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
// Tail number could be One to Seven
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::One>;
|
||||
Run(kernel);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
Run(kernel);
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Two>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Three)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Three>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Four)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Four>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Five)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Five>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Six>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Seven)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Seven>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::One>;
|
||||
Run(kernel);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
Run(kernel);
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Two>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Three)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Three>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Four)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Four>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Five)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Five>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Six>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Seven)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Seven>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Tail number could be Odd or Even
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const BDataType* p_b,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation)
|
||||
{
|
||||
return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
KBatch);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmXdlUniversal"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerXDL<<"x"<<NPerXDL << ", "
|
||||
<< "WaveMap: "
|
||||
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -22,6 +22,7 @@ struct PassThroughPack2
|
||||
auto t = type_convert<float2_t>(x);
|
||||
y = type_convert<half2_t>(t);
|
||||
}
|
||||
constexpr const static bool is_pack2_invocable = true;
|
||||
};
|
||||
|
||||
struct PassThrough
|
||||
@@ -131,12 +132,24 @@ struct PassThrough
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
|
||||
{
|
||||
y = type_convert<int32_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, int8_t>(float& y, const int8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
__host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
|
||||
|
||||
@@ -259,46 +259,20 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
|
||||
BlockToCTileMap_M00_N0_M01Adapt;
|
||||
};
|
||||
|
||||
// Rows of column-vectors
|
||||
// This C-tile map dynamically adjusts M01 when C-tile index is out of range
|
||||
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
|
||||
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt;
|
||||
// Grouped Rows of column-vectors WGP mapping
|
||||
// Optimized for MI300-like multipe-die chip
|
||||
|
||||
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
|
||||
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
|
||||
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
|
||||
|
||||
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(
|
||||
const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
|
||||
__host__ __device__
|
||||
BlockToCTileMap_Grouped_M00_N0_M01Adapt(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
|
||||
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
|
||||
operator=(const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
|
||||
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
|
||||
operator=(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
|
||||
|
||||
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M,
|
||||
index_t N,
|
||||
index_t M01 = 8)
|
||||
: M_(M), N_(N), M01_(M01)
|
||||
{
|
||||
#if 0
|
||||
if(get_thread_global_1d_id()==0){
|
||||
printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ __device__
|
||||
BlockToCTileMap_Grouped_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
|
||||
: BlockToCTileMap_Grouped_M00_N0_M01Adapt(
|
||||
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
|
||||
@@ -309,12 +283,6 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, v
|
||||
return M0 * N0;
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
__host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
|
||||
{
|
||||
@@ -329,67 +297,82 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, v
|
||||
const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
|
||||
|
||||
block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
|
||||
if(M0 == 1)
|
||||
{
|
||||
return make_tuple(0, block_1d_id);
|
||||
}
|
||||
else if(N0 == 1)
|
||||
{
|
||||
return make_tuple(block_1d_id, 0);
|
||||
}
|
||||
// block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
|
||||
else
|
||||
{
|
||||
const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum);
|
||||
const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
|
||||
auto group_id_x = block_1d_id % GroupNum;
|
||||
auto group_id_y = block_1d_id / GroupNum;
|
||||
auto remap_block_1d_id =
|
||||
group_id_x <= big_group_num
|
||||
? group_id_x * group_size + group_id_y
|
||||
: group_id_x * group_size + big_group_num - group_id_x + group_id_y;
|
||||
|
||||
const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum);
|
||||
auto group_id = block_1d_id % GroupNum;
|
||||
auto remap_block_1d_id = group_id * group_size + block_1d_id / GroupNum;
|
||||
index_t idx_N0 = remap_block_1d_id % N0;
|
||||
index_t idx_M0 = remap_block_1d_id / N0;
|
||||
|
||||
index_t idx_N0 = remap_block_1d_id % N0;
|
||||
index_t idx_M0 = remap_block_1d_id / N0;
|
||||
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
|
||||
|
||||
const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
|
||||
index_t idx_M00 = idx_M0 / M01_;
|
||||
index_t idx_M01 = idx_M0 % M01_;
|
||||
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
|
||||
|
||||
index_t idx_M00 = idx_M0 / M01_;
|
||||
index_t idx_M01 = idx_M0 % M01_;
|
||||
index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
|
||||
/**
|
||||
* idxN0
|
||||
*
|
||||
* |< mtx N >|
|
||||
*
|
||||
* NPerBlock NPerBlock NPerBlock NPerBlock
|
||||
* N_0 N_1 N_2 N_3
|
||||
* - |-----------|-----------|-----------|-----|-----|-
|
||||
* ^ | - - 0 |/----> 2 | | | |
|
||||
* | | | / | | | | | M_0 MPerBlock
|
||||
* | M | /| | | | | |
|
||||
* |-0---|---/-|-----|-----|-----------|-----|-----|-
|
||||
* | 1 | / | | | blockid | | |
|
||||
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
|
||||
* | - V 1 | - 3 | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* mtx M | | | | | |
|
||||
* | | | | | | M_2 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* | | | | | |
|
||||
* | | | | | | M_3 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* V | | | | | |
|
||||
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* Example:
|
||||
* assume:
|
||||
* M0 = 5
|
||||
* N0 = 4
|
||||
* block_1d_id = 5
|
||||
* M01 = 2
|
||||
*
|
||||
* idx_N0 = 1
|
||||
* idx_M0 = 1
|
||||
* M01_adapt = 2
|
||||
* idx_M00 = 0
|
||||
* idx_M01 = 1
|
||||
* idx_N0_M01_local = 5
|
||||
* output {1, 2}
|
||||
*/
|
||||
|
||||
/**
|
||||
* idxN0
|
||||
*
|
||||
* |< mtx N >|
|
||||
*
|
||||
* NPerBlock NPerBlock NPerBlock NPerBlock
|
||||
* N_0 N_1 N_2 N_3
|
||||
* - |-----------|-----------|-----------|-----|-----|-
|
||||
* ^ | - - 0 |/----> 2 | | | |
|
||||
* | | | / | | | | | M_0 MPerBlock
|
||||
* | M | /| | | | | |
|
||||
* |-0---|---/-|-----|-----|-----------|-----|-----|-
|
||||
* | 1 | / | | | blockid | | |
|
||||
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
|
||||
* | - V 1 | - 3 | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* mtx M | | | | | |
|
||||
* | | | | | | M_2 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* | | | | | |
|
||||
* | | | | | | M_3 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* V | | | | | |
|
||||
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
|
||||
* | | | | | |
|
||||
* |-----------|-----------|-----------|-----|-----|-
|
||||
* Example:
|
||||
* assume:
|
||||
* M0 = 5
|
||||
* N0 = 4
|
||||
* block_1d_id = 5
|
||||
* M01 = 2
|
||||
*
|
||||
* idx_N0 = 1
|
||||
* idx_M0 = 1
|
||||
* M01_adapt = 2
|
||||
* idx_M00 = 0
|
||||
* idx_M01 = 1
|
||||
* idx_N0_M01_local = 5
|
||||
* output {1, 2}
|
||||
*/
|
||||
|
||||
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
|
||||
idx_N0_M01_local / M01_adapt);
|
||||
return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
|
||||
idx_N0_M01_local / M01_adapt);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CTileIdx, typename CTileDim>
|
||||
@@ -405,15 +388,6 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, v
|
||||
index_t M01_;
|
||||
};
|
||||
|
||||
// keep the redundant type argument for backward compatibility
|
||||
template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
|
||||
struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
|
||||
: BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
|
||||
{
|
||||
using BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>::
|
||||
BlockToCTileMap_Grouped_M00_N0_M01Adapt;
|
||||
};
|
||||
|
||||
// columns of row-vectors
|
||||
// This C-tile map dynamically adjusts N01 when C-tile index is out of range
|
||||
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -202,15 +202,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
constexpr auto src_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
|
||||
|
||||
// maintain a container record is_src_valid, waiting for RunWrite use.
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
|
||||
src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
|
||||
|
||||
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using src_vector_t = typename src_vector_type::type;
|
||||
|
||||
// copy data from src_buf into src_vector_container
|
||||
auto src_vector_container = src_vector_type{
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
|
||||
auto src_vector_container =
|
||||
src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), true)};
|
||||
|
||||
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
|
||||
using dst_vector_t = typename dst_vector_type::type;
|
||||
@@ -305,12 +307,78 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
|
||||
});
|
||||
#else
|
||||
|
||||
// OOB Check
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// loop over tensor and copy
|
||||
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
|
||||
: ordered_src_access_lengths[i] - 1 -
|
||||
ordered_src_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
}();
|
||||
|
||||
constexpr auto src_data_idx_seq = generate_sequence_v2(
|
||||
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
|
||||
|
||||
using vector_t = typename vector_type_maker<DstData, SrcScalarPerVector>::type::type;
|
||||
|
||||
auto op_r = src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template GetAsType<vector_t>(src_data_idx_seq);
|
||||
|
||||
const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template GetAsType<bool>(src_data_idx_seq);
|
||||
|
||||
auto op_r_v = is_src_valid ? op_r : vector_t(0);
|
||||
|
||||
src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
|
||||
});
|
||||
|
||||
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
|
||||
// TODO make this logic more generic for more sub-dword datatype
|
||||
if constexpr(SrcVectorDim != DstVectorDim &&
|
||||
((is_same<half_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
|
||||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
|
||||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
|
||||
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
|
||||
{
|
||||
// each transpose does
|
||||
@@ -386,6 +454,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
|
||||
{
|
||||
// if there is transpose, it's done here
|
||||
// if there is oob check, it's done here
|
||||
// TODO move this elsewhere
|
||||
TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);
|
||||
|
||||
@@ -738,6 +807,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
return make_naive_tensor_descriptor_packed(src_access_lengths);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstThreadScratchDescriptor()
|
||||
{
|
||||
// 1st stage of transforms
|
||||
@@ -789,6 +868,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
|
||||
private:
|
||||
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
|
||||
static constexpr auto src_oob_thread_scratch_desc_ =
|
||||
decltype(GetSrcThreadScratchDescriptor()){};
|
||||
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
|
||||
|
||||
using SrcThreadScratch =
|
||||
@@ -798,6 +879,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
decltype(src_thread_scratch_desc_),
|
||||
true>;
|
||||
|
||||
using SrcOOBThreadScratch =
|
||||
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
bool, // apply data_convert with SrcThreadScratch
|
||||
1,
|
||||
decltype(src_oob_thread_scratch_desc_),
|
||||
true>;
|
||||
|
||||
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
|
||||
DstData,
|
||||
DstScalarPerVector,
|
||||
@@ -805,6 +893,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
true>;
|
||||
|
||||
StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
|
||||
StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
|
||||
|
||||
DstThreadScratch dst_thread_scratch_;
|
||||
|
||||
|
||||
104
include/ck/utility/blkgemmpipe_scheduler.hpp
Normal file
104
include/ck/utility/blkgemmpipe_scheduler.hpp
Normal file
@@ -0,0 +1,104 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_adaptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum struct BlockGemmPipelineScheduler
|
||||
{
|
||||
Intrawave,
|
||||
Interwave,
|
||||
};
|
||||
|
||||
enum struct TailNumber
|
||||
{
|
||||
// Single / Double buffer pipeline
|
||||
Odd,
|
||||
Even,
|
||||
|
||||
// Long prefetch pipeline, up to 8
|
||||
One,
|
||||
Two,
|
||||
Three,
|
||||
Four,
|
||||
Five,
|
||||
Six,
|
||||
Seven,
|
||||
|
||||
// Unroll stages > Prefetch stages, number of loop is multiple of unroll stages
|
||||
Empty,
|
||||
// Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add
|
||||
// prefetchstages
|
||||
Full,
|
||||
};
|
||||
template <index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t ABufferLoadWidth,
|
||||
index_t BBufferLoadWidth,
|
||||
index_t ALDSWriteWidth,
|
||||
index_t BLDSWriteWidth,
|
||||
index_t ALDSReadWidth,
|
||||
index_t BLDSReadWidth,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t KPerXDL>
|
||||
struct BlockwiseGemmXdlops_pipeline_hotloop_inst
|
||||
{
|
||||
static constexpr index_t WaveSize = 64;
|
||||
static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
|
||||
static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;
|
||||
|
||||
static constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
|
||||
static constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
|
||||
|
||||
static constexpr index_t A_LDS_Write_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
|
||||
static constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
|
||||
|
||||
static constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
|
||||
static constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
|
||||
|
||||
static constexpr index_t C_MFMA_Inst_Num =
|
||||
MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
static constexpr auto Print()
|
||||
{
|
||||
printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
|
||||
BlockSize,
|
||||
WaveSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
KPerXDL);
|
||||
|
||||
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
|
||||
"%d, %d\n C MFMA inst: %d\n",
|
||||
A_Buffer_Load_Inst_Num,
|
||||
B_Buffer_Load_Inst_Num,
|
||||
A_LDS_Write_Inst_Num,
|
||||
B_LDS_Write_Inst_Num,
|
||||
A_LDS_Read_Inst_Num,
|
||||
B_LDS_Read_Inst_Num,
|
||||
C_MFMA_Inst_Num);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -163,6 +163,13 @@ struct scalar_type<bf8_t>
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_type<bool>
|
||||
{
|
||||
using type = bool;
|
||||
static constexpr index_t vector_size = 1;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 1>
|
||||
{
|
||||
|
||||
@@ -10,10 +10,12 @@ namespace ck {
|
||||
__device__ void block_sync_lds()
|
||||
{
|
||||
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
asm volatile("\
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
s_barrier \
|
||||
" ::);
|
||||
// asm volatile("\
|
||||
// s_waitcnt lgkmcnt(0) \n \
|
||||
// s_barrier \
|
||||
// " ::);
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
#else
|
||||
__syncthreads();
|
||||
#endif
|
||||
|
||||
@@ -162,4 +162,83 @@ struct transpose_vectors<int8_t, NX, NY>
|
||||
}
|
||||
};
|
||||
|
||||
// transpose f8 4x4
|
||||
__device__ void transpose_f8_4x4(const f8x4_t& x0,
|
||||
const f8x4_t& x1,
|
||||
const f8x4_t& x2,
|
||||
const f8x4_t& x3,
|
||||
f8x4_t& y0,
|
||||
f8x4_t& y1,
|
||||
f8x4_t& y2,
|
||||
f8x4_t& y3)
|
||||
{
|
||||
int32_t t0, t1;
|
||||
int32_t z0, z1, z2, z3;
|
||||
constexpr int32_t m0 = 0x05010400;
|
||||
constexpr int32_t m1 = 0x05040100;
|
||||
constexpr int32_t m2 = 0x07060302;
|
||||
constexpr int32_t m3 = 0x07030602;
|
||||
|
||||
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
|
||||
// -- -- -- -- -- -- -- -- - - - -
|
||||
// index 7 6 5 4 3 2 1 0 33 77 44 88
|
||||
// index is reversed because of little endianness (least significant bits first)
|
||||
t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
|
||||
t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
|
||||
z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
|
||||
z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
|
||||
t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
|
||||
t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
|
||||
z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
|
||||
z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
|
||||
|
||||
y0 = bit_cast<f8x4_t>(z0);
|
||||
y1 = bit_cast<f8x4_t>(z1);
|
||||
y2 = bit_cast<f8x4_t>(z2);
|
||||
y3 = bit_cast<f8x4_t>(z3);
|
||||
}
|
||||
|
||||
template <index_t NX, index_t NY>
|
||||
struct transpose_vectors<f8_t, NX, NY>
|
||||
{
|
||||
// we got [NY * NX] amount of S data to be transposed
|
||||
static constexpr index_t s_per_x = NY;
|
||||
static constexpr index_t s_per_y = NX;
|
||||
|
||||
using S = f8_t;
|
||||
using VX = vector_type<f8_t, s_per_x>;
|
||||
using VY = vector_type<f8_t, s_per_y>;
|
||||
|
||||
__device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
|
||||
StaticallyIndexedArray<VY&, NY>& vy_tuple)
|
||||
{
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
|
||||
static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");
|
||||
|
||||
// loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
|
||||
static_for<0, NY, 4>{}([&](auto iy) {
|
||||
static_for<0, NX, 4>{}([&](auto ix) {
|
||||
// reference to 4 f8 data from vx_tuple
|
||||
const auto& x_s4_0 = vx_tuple[ix].template AsType<f8x4_t>()[iy / I4];
|
||||
const auto& x_s4_1 = vx_tuple[ix + I1].template AsType<f8x4_t>()[iy / I4];
|
||||
const auto& x_s4_2 = vx_tuple[ix + I2].template AsType<f8x4_t>()[iy / I4];
|
||||
const auto& x_s4_3 = vx_tuple[ix + I3].template AsType<f8x4_t>()[iy / I4];
|
||||
|
||||
// reference to 4 f8 data from vy_tuple
|
||||
auto& y_s4_0 = vy_tuple(iy).template AsType<f8x4_t>()(ix / I4);
|
||||
auto& y_s4_1 = vy_tuple(iy + I1).template AsType<f8x4_t>()(ix / I4);
|
||||
auto& y_s4_2 = vy_tuple(iy + I2).template AsType<f8x4_t>()(ix / I4);
|
||||
auto& y_s4_3 = vy_tuple(iy + I3).template AsType<f8x4_t>()(ix / I4);
|
||||
|
||||
// transpose
|
||||
transpose_f8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -43,6 +43,8 @@ __host__ __device__ constexpr Y bit_cast(const X& x)
|
||||
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
|
||||
Y y;
|
||||
|
||||
// auto t = reinterpret_cast<const Y*>(&x);
|
||||
// y = *t;
|
||||
__builtin_memcpy(&y, &x, sizeof(X));
|
||||
|
||||
return y;
|
||||
|
||||
@@ -21,7 +21,7 @@ template <typename ADataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename ComputeTypeA = ADataType,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct ReferenceGemm : public device::BaseOperator
|
||||
{
|
||||
|
||||
@@ -0,0 +1,505 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmV2<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmV2<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances(op_ptrs);
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -31,6 +31,18 @@ struct GeneratorTensor_1
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<ck::half_t>
|
||||
{
|
||||
float value = 1.0;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bhalf_t operator()(Is...)
|
||||
{
|
||||
return ck::type_convert<ck::half_t>(value);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<ck::bhalf_t>
|
||||
{
|
||||
@@ -43,6 +55,20 @@ struct GeneratorTensor_1<ck::bhalf_t>
|
||||
}
|
||||
};
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
template <>
|
||||
struct GeneratorTensor_1<ck::f8_t>
|
||||
{
|
||||
float value = 1.0;
|
||||
|
||||
template <typename... Is>
|
||||
ck::bhalf_t operator()(Is...)
|
||||
{
|
||||
return ck::type_convert<ck::f8_t>(value);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<int8_t>
|
||||
{
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(GEMM_UNIVERSAL_INSTANCES)
|
||||
|
||||
list(APPEND GEMM_UNIVERSAL_INSTANCES
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
|
||||
|
||||
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
|
||||
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
|
||||
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_gemm_universal_instance ${GEMM_UNIVERSAL_INSTANCES})
|
||||
@@ -0,0 +1,91 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Latency friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// Memory friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 2, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 4, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 4, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances<GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances<GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances<GemmMNPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances<Intrawave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances<Intrawave, GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances<Intrawave, GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances<Interwave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances<Interwave, GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances<Interwave, GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,98 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Compute friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
// AGPR Spill
|
||||
// DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
|
||||
// AGPR Spill when use permuted lds layout. so, use padding for these two.
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Latency friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// Memory friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances<GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances<GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances<GemmMNPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances<Intrawave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances<Intrawave, GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances<Intrawave, GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances<Interwave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances<Interwave, GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances<Interwave, GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,80 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F8 = f8_t;
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Latency friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// Memory friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances<GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances<GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances<GemmMNPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances<Intrawave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances<Intrawave, GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances<Intrawave, GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances<Interwave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances<Interwave, GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances<Interwave, GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,86 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F8 = f8_t;
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Compute friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Latency friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// Memory friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 128, 8, 16, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 128, 8, 16, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 128, 8, 16, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 128, 8, 16, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 128, 8, 16, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 8, 16, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 8, 16, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances<GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances<GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances<GemmMNPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances<Intrawave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances<Intrawave, GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances<Intrawave, GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances<Interwave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances<Interwave, GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances<Interwave, GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F8 = f8_t;
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 256, 64, 16, 8, 32, 32, 3, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Latency friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 2, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// Memory friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 128, 16, 2, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 128, 16, 2, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 128, 16, 4, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 128, 16, 2, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 128, 16, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 128, 16, 2, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 2, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances<GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances<GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances<GemmMNPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances<Intrawave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances<Intrawave, GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances<Intrawave, GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances<Interwave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances<Interwave, GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances<Interwave, GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,86 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F8 = f8_t;
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Compute friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 16, 8, 16, 16, 8, 7, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
// Latency friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// Memory friendly
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 128, 16, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 128, 16, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 128, 16, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 128, 16, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 128, 16, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 128, 16, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 128, 16, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 128, 16, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 128, 16, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 16, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances<GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances<GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances<GemmMNPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances<Intrawave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances<Intrawave, GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances<Intrawave, GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances<Interwave, GemmDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances<Interwave, GemmKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances<Interwave, GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
299
profiler/include/profiler/profile_gemm_universal_impl.hpp
Normal file
299
profiler/include/profiler/profile_gemm_universal_impl.hpp
Normal file
@@ -0,0 +1,299 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <typeinfo>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_universal.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
bool profile_gemm_universal_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int StrideA,
|
||||
int StrideB,
|
||||
int StrideC,
|
||||
int KBatch,
|
||||
int n_warmup,
|
||||
int n_iter)
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto c_element_op = CElementOp{};
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGemmV2<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
// Run reference GEMM
|
||||
if(do_verification)
|
||||
{
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
float best_kbatch = 0;
|
||||
|
||||
// profile device GEMM instances
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
std::vector<int> kbatch_list = {1, 2, 4, 8, 12, 16, 19, 20, 32, 38};
|
||||
|
||||
if(KBatch > 0)
|
||||
{
|
||||
kbatch_list = {KBatch};
|
||||
}
|
||||
|
||||
for(std::size_t i = 0; i < kbatch_list.size(); i++)
|
||||
{
|
||||
auto kbatch_curr = kbatch_list[i];
|
||||
|
||||
auto argument_ptr =
|
||||
op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
kbatch_curr,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
|
||||
// re-init C to zero before profiling next kernel
|
||||
c_device_buf.SetZero();
|
||||
|
||||
invoker_ptr->Run(argument_ptr.get(),
|
||||
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "c_device: ", c_m_n_device_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
float ave_time = invoker_ptr->Run(
|
||||
argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
|
||||
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
|
||||
<< kbatch_curr << std::endl;
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
// set softer tolerances for fp8
|
||||
if constexpr(is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
|
||||
is_same_v<CDataType, f8_t>)
|
||||
{
|
||||
std::string msg = "Error: Incorrect results!";
|
||||
double rtol = 1e-1;
|
||||
double atol = 1e-1;
|
||||
pass = pass & ck::utils::check_err(
|
||||
c_m_n_device_result, c_m_n_host_result, msg, rtol, atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
|
||||
#if defined CK_ENABLE_FP8
|
||||
}
|
||||
#endif
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
best_kbatch = kbatch_curr;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_ptr->GetTypeString() << " does not support this problem"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<CDataType, float>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = f32";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, half_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = f16";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, bhalf_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = bf16";
|
||||
}
|
||||
else if constexpr(is_same<CDataType, int8_t>::value)
|
||||
{
|
||||
std::cout << "Best Perf for datatype = int8";
|
||||
}
|
||||
|
||||
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
std::cout << " ALayout = RowMajor";
|
||||
}
|
||||
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
|
||||
{
|
||||
std::cout << " ALayout = ColumnMajor";
|
||||
}
|
||||
|
||||
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
std::cout << " BLayout = RowMajor";
|
||||
}
|
||||
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
|
||||
{
|
||||
std::cout << " BLayout = ColumnMajor";
|
||||
}
|
||||
|
||||
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
|
||||
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
|
||||
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
|
||||
<< " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
@@ -49,6 +49,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp)
|
||||
@@ -115,6 +116,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user