mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into update_cka8w8
This commit is contained in:
@@ -145,20 +145,20 @@ message("hip_version_flat=${hip_VERSION_FLAT}")
|
||||
|
||||
message("checking which targets are supported")
|
||||
#In order to build just the CK library (without tests and examples) for all supported GPU targets
|
||||
#use -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
#use -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
#the GPU_TARGETS flag will be reset in this case in order to avoid conflicts.
|
||||
#
|
||||
#In order to build CK along with all tests and examples it should be OK to set GPU_TARGETS to just 1 or 2 similar architectures.
|
||||
if(NOT ENABLE_ASAN_PACKAGING)
|
||||
if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000)
|
||||
# WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
|
||||
else()
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
|
||||
endif()
|
||||
else()
|
||||
#build CK only for xnack-supported targets when using ASAN
|
||||
set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+")
|
||||
set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx942:xnack+")
|
||||
endif()
|
||||
|
||||
#if user set GPU_ARCHS on the cmake command line, overwrite default target list with user's list
|
||||
@@ -183,12 +183,14 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
message("Enabling XDL instances")
|
||||
add_definitions(-DCK_USE_XDL)
|
||||
set(CK_USE_XDL "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
message("Enabling FP8 gemms in ckProfiler")
|
||||
add_definitions(-DCK_USE_GFX94)
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
message("Enabling WMMA instances")
|
||||
add_definitions(-DCK_USE_WMMA)
|
||||
set(CK_USE_WMMA "ON")
|
||||
endif()
|
||||
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
|
||||
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
|
||||
@@ -221,7 +223,7 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090)
|
||||
endif()
|
||||
set(check-coerce)
|
||||
check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce)
|
||||
if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132 AND ${hip_VERSION_FLAT} LESS 600300000)
|
||||
if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132)
|
||||
message("Adding the amdgpu-coerce-illegal-types=1")
|
||||
add_compile_options("SHELL: -mllvm -amdgpu-coerce-illegal-types=1")
|
||||
endif()
|
||||
|
||||
@@ -24,10 +24,10 @@ RUN if [ "$ROCMVERSION" != "6.3" ]; then \
|
||||
sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \
|
||||
sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \
|
||||
elif [ "$ROCMVERSION" = "6.3" ] && [ "$compiler_version" = "rc1" ]; then \
|
||||
sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.3.0.1-20.04-1_all.deb --no-check-certificate" && \
|
||||
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.3.0.1-20.04-1_all.deb && \
|
||||
sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.3.0.1 rel-5 > /etc/apt/sources.list.d/rocm-build.list' && \
|
||||
amdgpu-repo --amdgpu-build=2033700; \
|
||||
sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.3-20.04-1_all.deb --no-check-certificate" && \
|
||||
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.3-20.04-1_all.deb && \
|
||||
sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.3 rel-20 > /etc/apt/sources.list.d/rocm-build.list' && \
|
||||
amdgpu-repo --amdgpu-build=2074281; \
|
||||
fi
|
||||
|
||||
RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list"
|
||||
|
||||
6
Jenkinsfile
vendored
6
Jenkinsfile
vendored
@@ -1101,11 +1101,11 @@ pipeline {
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \
|
||||
-DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \
|
||||
-DGPU_TARGETS="gfx908;gfx90a;gfx942" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \
|
||||
-DGPU_TARGETS="gfx908;gfx90a;gfx942" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
@@ -1165,7 +1165,7 @@ pipeline {
|
||||
execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \
|
||||
-D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \
|
||||
-D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
|
||||
}
|
||||
steps{
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core==1.8.3
|
||||
rocm-docs-core==1.8.4
|
||||
sphinxcontrib-bibtex==2.6.3
|
||||
|
||||
@@ -103,7 +103,7 @@ requests==2.32.3
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core==1.8.3
|
||||
rocm-docs-core==1.8.4
|
||||
# via -r requirements.in
|
||||
six==1.16.0
|
||||
# via pybtex
|
||||
|
||||
@@ -80,7 +80,7 @@ using RLayout = typename LayoutSettingSelector<NDimSpatial>::RLayout;
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
int init_method = 2;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
|
||||
@@ -73,16 +73,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
|
||||
Tensor<EDataType> conv_output_device(conv_output_g_n_k_wos_desc);
|
||||
Tensor<R0DataType> r0_device(r0_desc);
|
||||
|
||||
std::cout << "input: " << conv_input.mDesc << std::endl;
|
||||
std::cout << "weight: " << conv_weight.mDesc << std::endl;
|
||||
std::cout << "output: " << conv_output_device.mDesc << std::endl;
|
||||
std::cout << "reduction: " << r0_device.mDesc << std::endl << std::endl;
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-8, 7}(conv_input);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-8, 7}(conv_weight);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-1, 1}(conv_weight);
|
||||
break;
|
||||
case 2:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-8, 7}(conv_input);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-1, 1}(conv_weight);
|
||||
break;
|
||||
default:
|
||||
ck::utils::FillUniformDistribution<ADataType>{-5, 5}(conv_input);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-5, 5}(conv_weight);
|
||||
ck::utils::FillUniformDistribution<ADataType>{-8, 7}(conv_input);
|
||||
ck::utils::FillUniformDistribution<BDataType>{-1, 1}(conv_weight);
|
||||
}
|
||||
|
||||
DeviceMem conv_input_device_buf(sizeof(ADataType) * conv_input.mDesc.GetElementSpaceSize());
|
||||
@@ -161,15 +170,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
|
||||
return false;
|
||||
}
|
||||
|
||||
// XXX: DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle will not initialize r0.
|
||||
r0_device_buf.SetValue(ck::NumericLimits<R0DataType>::Lowest());
|
||||
|
||||
const float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
const std::size_t flop = problem_size.GetFlops();
|
||||
const std::size_t num_btype = problem_size.GetByte<ADataType, BDataType, EDataType>();
|
||||
if(config.time_kernel)
|
||||
{
|
||||
const std::size_t flop = problem_size.GetFlops();
|
||||
const std::size_t num_btype = problem_size.GetByte<ADataType, BDataType, EDataType>();
|
||||
|
||||
const float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
const float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< conv.GetTypeString() << std::endl;
|
||||
const float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
const float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << conv.GetTypeString() << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "FINISHED: " << conv.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(config.do_verification)
|
||||
{
|
||||
@@ -189,6 +208,7 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
|
||||
BElementOp{},
|
||||
PassThrough{});
|
||||
|
||||
std::cout << "\nRunning verification on CPU." << std::endl;
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
Tensor<R0DataType> r0_host(r0_device.mDesc);
|
||||
@@ -273,13 +293,18 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size,
|
||||
conv_output_device_buf.FromDevice(conv_output_device.mData.data());
|
||||
r0_device_buf.FromDevice(r0_device.mData.data());
|
||||
|
||||
return ck::utils::check_err(conv_output_device,
|
||||
conv_output_host,
|
||||
"Error: incorrect results! (Matrix E)",
|
||||
1e-5f,
|
||||
1e-4f) &&
|
||||
ck::utils::check_err(
|
||||
r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-5f, 1e-4f);
|
||||
auto pass = ck::utils::check_err(conv_output_device,
|
||||
conv_output_host,
|
||||
"Error: incorrect results! (Matrix E)",
|
||||
1e-3f,
|
||||
1e-3f);
|
||||
pass =
|
||||
pass && ck::utils::check_err(
|
||||
r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-3f, 1e-3f);
|
||||
if(pass)
|
||||
std::cout << "Verification on CPU: PASS" << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
@@ -198,7 +198,7 @@ int main()
|
||||
throw std::runtime_error("wrong! this device_op instance does not support this problem");
|
||||
}
|
||||
|
||||
// init reducetion buffer to 0
|
||||
// init reduction buffer to 0
|
||||
r0_device_buf.SetZero();
|
||||
r1_device_buf.SetZero();
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
|
||||
|
||||
using DeviceReduceInstance =
|
||||
ck::tensor_operation::device::DeviceReduceMultiBlock<OutputDataType,
|
||||
OutputDataType,
|
||||
ScaleDataType,
|
||||
OutputDataType,
|
||||
NumDim,
|
||||
NumDim,
|
||||
@@ -108,7 +108,8 @@ void reference_scale_permute_amax(Tensor<InputDataType>& input,
|
||||
host_output_scaled_casted_transposed(m, k) = y1;
|
||||
const OutputDataType y_fabs =
|
||||
ck::type_convert<OutputDataType>(ck::math::abs(ck::type_convert<float>(y0)));
|
||||
host_output_amax(0) = ck::math::max(y_fabs, host_output_amax(0));
|
||||
host_output_amax(0) = ck::type_convert<OutputDataType>(ck::math::max(
|
||||
ck::type_convert<float>(y_fabs), ck::type_convert<float>(host_output_amax(0))));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,9 +85,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
#only continue if there are some source files left on the list
|
||||
if(FILE_NAME)
|
||||
if(FILE_NAME MATCHES "_xdl")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
endif()
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
@@ -169,9 +169,9 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
#only continue if there are some source files left on the list
|
||||
if(FILE_NAME)
|
||||
if(FILE_NAME MATCHES "_xdl")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
endif()
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
|
||||
@@ -29,14 +29,14 @@ while getopts ":sa" opt; do
|
||||
done
|
||||
|
||||
run_fp16_bf16_tests() {
|
||||
local NUM_SPLITS=(1)
|
||||
local PAGE_BLOCK_SIZE=(0)
|
||||
local CACHE_BATCH_IDX=(0)
|
||||
local NUM_SPLITS="1"
|
||||
local PAGE_BLOCK_SIZE="0"
|
||||
local CACHE_BATCH_IDX="0"
|
||||
|
||||
if [ $TEST_SPLITKV -eq 1 ] ; then
|
||||
NUM_SPLITS+=(2 3)
|
||||
PAGE_BLOCK_SIZE+=(128)
|
||||
CACHE_BATCH_IDX+=(1)
|
||||
NUM_SPLITS="$NUM_SPLITS 2 3"
|
||||
PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128"
|
||||
CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1"
|
||||
fi
|
||||
|
||||
for prec in "fp16" "bf16" ; do
|
||||
@@ -47,9 +47,9 @@ run_fp16_bf16_tests() {
|
||||
for lse in 0 1 ; do
|
||||
for bias in "n" "e" "a" ; do
|
||||
for p_drop in 0.0 0.2 ; do
|
||||
for num_splits in "${NUM_SPLITS[@]}" ; do
|
||||
for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do
|
||||
for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do
|
||||
for num_splits in $NUM_SPLITS ; do
|
||||
for page_block_size in $PAGE_BLOCK_SIZE ; do
|
||||
for cache_batch_idx in $CACHE_BATCH_IDX ; do
|
||||
|
||||
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
@@ -103,4 +103,4 @@ if [ $TEST_APPENDKV -eq 1 ] ; then
|
||||
run_fp16_appendkv_tests
|
||||
fi
|
||||
|
||||
set +x
|
||||
set +x
|
||||
|
||||
@@ -57,6 +57,7 @@ template <typename XDataType_,
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kFastFDiv_,
|
||||
bool kTwoPass_,
|
||||
ck_tile::index_t kFusedAdd_ = 0,
|
||||
ck_tile::index_t kFusedQuant_ = 0>
|
||||
@@ -118,6 +119,7 @@ struct layernorm2d_fwd_traits_
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
static constexpr bool kFastFDiv = kFastFDiv_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
||||
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
||||
@@ -134,6 +136,7 @@ template <typename XDataType_,
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kFastFDiv_,
|
||||
bool kTwoPass_,
|
||||
int kFusedAdd_,
|
||||
int kFusedQuant_>
|
||||
@@ -148,6 +151,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kSaveMeanInvStd_,
|
||||
kFastFDiv_,
|
||||
kTwoPass_,
|
||||
kFusedAdd_,
|
||||
kFusedQuant_>;
|
||||
@@ -179,6 +183,7 @@ float layernorm2d_fwd_(const S& s, A a)
|
||||
|
||||
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
|
||||
Traits_::kSaveMeanInvStd,
|
||||
Traits_::kFastFDiv,
|
||||
Traits_::kTwoPass,
|
||||
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
|
||||
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
|
||||
@@ -269,7 +274,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
#include "layernorm2d_fwd_api_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// prec_i prec_o prec_sy rm rn tm tn vn pd mv 2p add sweep
|
||||
// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf 2p add sweep
|
||||
{F_instance_def}
|
||||
// clang-format on
|
||||
|
||||
@@ -356,6 +361,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
F_Vector_N : int
|
||||
F_kPadN : bool
|
||||
F_kSaveMeanInvStd_ : bool
|
||||
F_kFastFDiv_ : bool
|
||||
F_kTwoPass_ : bool
|
||||
F_kFusedAdd : int
|
||||
F_kFusedQuant : int
|
||||
@@ -363,7 +369,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
@property
|
||||
def trait_name(self) ->str:
|
||||
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
|
||||
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}'
|
||||
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
|
||||
return t_
|
||||
|
||||
@@ -483,52 +489,55 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
fused_add_list = [0, 1]
|
||||
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
|
||||
|
||||
# rm rn tm tn vn pd mv 2p add sweep
|
||||
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]}
|
||||
# rm rn tm tn vn pd mv fdiv 2p add sweep
|
||||
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, False, 0, 0)],
|
||||
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, False, 0, 0)],
|
||||
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, False, 0, 0)],
|
||||
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, False, 0, 0)],
|
||||
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, False, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, False, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, False, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, False, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, False, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, False, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, False, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, False, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, 0, 0)]}
|
||||
total_blob = list()
|
||||
for hs_key in h_trait_dict:
|
||||
hs = h_trait_dict[hs_key]
|
||||
|
||||
@@ -25,7 +25,10 @@ auto create_args(int argc, char* argv[])
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3328", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to n")
|
||||
.insert("x_stride", "-1", "x row_stride, if -1 then equal to n")
|
||||
.insert("xr_stride", "-1", "x residule row_stride, if -1 then equal to n")
|
||||
.insert("y_stride", "-1", "y row_stride, if -1 then equal to n")
|
||||
.insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n")
|
||||
.insert("e", "1e-5", "epsilon")
|
||||
.insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
@@ -54,11 +57,20 @@ template <typename InDataType,
|
||||
bool SaveMeanVar>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
if(stride < 0)
|
||||
stride = n;
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
|
||||
if(x_stride < 0)
|
||||
x_stride = n;
|
||||
ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride");
|
||||
if(xr_stride < 0)
|
||||
xr_stride = n;
|
||||
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
|
||||
if(y_stride < 0)
|
||||
y_stride = n;
|
||||
ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride");
|
||||
if(yr_stride < 0)
|
||||
yr_stride = n;
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
@@ -89,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return false;
|
||||
}
|
||||
|
||||
assert(stride >= n);
|
||||
assert(x_stride >= n);
|
||||
|
||||
using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, XScaleDataType, YScaleDataType>;
|
||||
|
||||
@@ -108,15 +120,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
|
||||
ck_tile::HostTensor<GammaDataType> gamma_host({n});
|
||||
ck_tile::HostTensor<BetaDataType> beta_host({n});
|
||||
|
||||
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {xr_stride, 1});
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {yr_stride, 1});
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {y_stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {y_stride, 1});
|
||||
|
||||
ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
|
||||
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
|
||||
@@ -162,7 +174,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}();
|
||||
|
||||
std::cout << "[" << prec_str << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
|
||||
<< ", yr_stride:" << yr_stride << std::flush;
|
||||
|
||||
layernorm2d_fwd_traits traits{
|
||||
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};
|
||||
@@ -182,7 +196,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
epsilon,
|
||||
m,
|
||||
n,
|
||||
stride};
|
||||
x_stride, // x row_stride
|
||||
xr_stride, // x residule row stride
|
||||
y_stride, // y row stride
|
||||
yr_stride}; // y residule row stride
|
||||
|
||||
float ave_time = layernorm2d_fwd(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
@@ -285,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
y_buf.FromDevice(y_host_dev.data());
|
||||
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {yr_stride, 1});
|
||||
if(fused_add == 1)
|
||||
{
|
||||
y_residual_buf.FromDevice(y_residual_host_dev.data());
|
||||
@@ -293,7 +310,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
auto [rtol, atol] = get_elimit<InDataType>();
|
||||
|
||||
if(stride == n)
|
||||
if(x_stride == n)
|
||||
{
|
||||
pass = ck_tile::check_err(
|
||||
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
@@ -310,10 +327,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
for(int i_r = 0; i_r < m; i_r++)
|
||||
{
|
||||
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
|
||||
y_host_dev.begin() + i_r * stride + n);
|
||||
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
|
||||
y_host_ref.begin() + i_r * stride + n);
|
||||
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * y_stride,
|
||||
y_host_dev.begin() + i_r * y_stride + n);
|
||||
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * y_stride,
|
||||
y_host_ref.begin() + i_r * y_stride + n);
|
||||
pass &= ck_tile::check_err(y_host_dev_row,
|
||||
y_host_ref_row,
|
||||
std::string("OUT[") + std::to_string(i_r) +
|
||||
@@ -323,10 +340,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(fused_add == 1)
|
||||
{
|
||||
std::vector<YResidualDataType> y_residual_host_dev_row(
|
||||
y_residual_host_dev.begin() + i_r * stride,
|
||||
y_residual_host_dev.begin() + i_r * stride + n);
|
||||
y_residual_host_dev.begin() + i_r * yr_stride,
|
||||
y_residual_host_dev.begin() + i_r * yr_stride + n);
|
||||
std::vector<YResidualDataType> y_residual_host_ref_row(
|
||||
x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n);
|
||||
x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n);
|
||||
pass &= ck_tile::check_err(y_residual_host_dev_row,
|
||||
y_residual_host_ref_row,
|
||||
std::string("ADD[") + std::to_string(i_r) +
|
||||
|
||||
@@ -8,7 +8,10 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
# The basic pipeline method on the gemm calculation
|
||||
make tile_example_gemm_basic -j
|
||||
# The memory bound pipeline on the gemm calculation
|
||||
make tile_example_gemm_mem_pipeline -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_basic`
|
||||
|
||||
|
||||
@@ -17,10 +17,11 @@
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadA = true;
|
||||
constexpr bool kPadB = true;
|
||||
constexpr bool kPadC = true;
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr bool kTilePermute = false;
|
||||
// The rank and permutation will also be generate out by the CodeGen part.
|
||||
constexpr ck_tile::index_t kOutputRank = 2;
|
||||
@@ -56,8 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
CShuffleEpilogue,
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
kPadA,
|
||||
kPadB,
|
||||
kPadM,
|
||||
kPadN,
|
||||
kTilePermute,
|
||||
kOutputRank,
|
||||
1,
|
||||
@@ -65,13 +66,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
TilePartitioner::kM,
|
||||
TilePartitioner::kN>>,
|
||||
ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>;
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using CodegenPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy<ALayout, BLayout, CLayout>;
|
||||
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
|
||||
using CodegenGemmPipeline =
|
||||
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
|
||||
@@ -31,9 +31,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadA = true;
|
||||
constexpr bool kPadB = true;
|
||||
constexpr bool kPadC = true;
|
||||
constexpr bool kPadM = true;
|
||||
constexpr bool kPadN = true;
|
||||
constexpr bool kPadK = true;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
@@ -46,9 +46,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
|
||||
|
||||
using GemmEpilogue = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
|
||||
|
||||
8
example/ck_tile/13_moe_sorting/CMakeLists.txt
Normal file
8
example/ck_tile/13_moe_sorting/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
add_executable(tile_example_moe_sorting EXCLUDE_FROM_ALL moe_sorting.cpp moe_sorting_api.cpp)
|
||||
target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
|
||||
|
||||
set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
target_compile_options(tile_example_moe_sorting PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS})
|
||||
27
example/ck_tile/13_moe_sorting/README.md
Normal file
27
example/ck_tile/13_moe_sorting/README.md
Normal file
@@ -0,0 +1,27 @@
|
||||
# moe-sorting
|
||||
|
||||
This folder contains example for moe-sorting kernel using ck_tile tile-programming implementation. This kernel is often used in Moe model, before launching the fused-moe-gemm block. The input&weight is a `token*topk` 2d matrix. The op rearange the input weight ids into different experts and feed into fuse moe gemm kernel.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_moe_sorting -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_moe_sorting`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-v weather do CPU validation or not (default:1)
|
||||
-pr_i index data type. (currently only fp32 supported now) (default:int32)
|
||||
-pr_w output weight data type(currently only fp32 supported now) (default:fp32)
|
||||
-t number of input tokens (default:32)
|
||||
-e number of experts (default:8)
|
||||
-k topk (default:2)
|
||||
-st_i row stride of input, -1 means same as experts (default:-1)
|
||||
-seed seed to be used, -1 means random every time (default:-1)
|
||||
-kname when set to 1 it will print kernel name (default:0)
|
||||
|
||||
```
|
||||
223
example/ck_tile/13_moe_sorting/moe_sorting.cpp
Normal file
223
example/ck_tile/13_moe_sorting/moe_sorting.cpp
Normal file
@@ -0,0 +1,223 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <time.h>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "moe_sorting_api.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("pr_i", "int32", "index data type. (currently only int32 supported now)")
|
||||
.insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
|
||||
.insert("t", "128", "number of input tokens")
|
||||
.insert("e", "8", "number of num_experts")
|
||||
.insert("k", "4", "topk")
|
||||
.insert("unit", "32", "unit_size")
|
||||
.insert("moe_buf_size", "0", "moe_buf_size")
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "when set to 1 it will print kernel name")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename IndexType>
|
||||
void topid_unique_gen(
|
||||
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
|
||||
{
|
||||
size_t total_size = topk * tokens;
|
||||
std::srand(seed);
|
||||
std::set<IndexType> unique_set;
|
||||
IndexType current_v;
|
||||
for(size_t i = 0; i < total_size; i++)
|
||||
{
|
||||
if(i % topk == 0)
|
||||
{
|
||||
unique_set.clear();
|
||||
}
|
||||
current_v = std::rand() % num_expert;
|
||||
while(unique_set.find(current_v) != unique_set.end())
|
||||
{
|
||||
current_v = std::rand() % num_expert;
|
||||
}
|
||||
unique_set.insert(current_v);
|
||||
host_tensor[i] = current_v;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
{
|
||||
int validate = args.get_int("v");
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
int tokens = args.get_int("t");
|
||||
int num_experts = args.get_int("e");
|
||||
int topk = args.get_int("k");
|
||||
int seed = args.get_int("seed");
|
||||
int unit_size = args.get_int("unit");
|
||||
int moe_buf_size = args.get_int("moe_buf_size");
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
int max_output_ids =
|
||||
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
|
||||
|
||||
if(seed < 0)
|
||||
{
|
||||
seed = std::time(nullptr);
|
||||
}
|
||||
|
||||
if(topk > num_experts)
|
||||
{
|
||||
printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n",
|
||||
topk,
|
||||
num_experts);
|
||||
return false;
|
||||
}
|
||||
|
||||
// tokens already considered batch size
|
||||
ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
|
||||
ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
|
||||
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
|
||||
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
|
||||
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
|
||||
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
|
||||
|
||||
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_expert_ids_dev(
|
||||
sorted_expert_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
|
||||
|
||||
topk_ids_dev.ToDevice(topk_ids_host.data());
|
||||
weights_dev.ToDevice(weights_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
{
|
||||
moe_buf_dev.ToDevice(moe_buf_host.data());
|
||||
}
|
||||
|
||||
moe_sorting_trait trait{index_prec, weight_prec};
|
||||
|
||||
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
|
||||
weights_dev.GetDeviceBuffer(),
|
||||
sorted_ids_dev.GetDeviceBuffer(),
|
||||
sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
sorted_id_cnt_dev.GetDeviceBuffer(),
|
||||
moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
tokens,
|
||||
unit_size,
|
||||
num_experts,
|
||||
topk,
|
||||
static_cast<ck_tile::index_t>(moe_buf_size * sizeof(float))};
|
||||
|
||||
ck_tile::stream_config sc{nullptr,
|
||||
true,
|
||||
/* log_level = */ (kname ? 1 : 0),
|
||||
warmup,
|
||||
repeat};
|
||||
auto ms = moe_sorting(trait, karg, sc);
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
ms);
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
fflush(stdout);
|
||||
if(ms < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
sorted_ids_dev.FromDevice(sorted_ids_host.data());
|
||||
sorted_weights_dev.FromDevice(sorted_weights_host.data());
|
||||
sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
|
||||
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
{
|
||||
moe_buf_dev.FromDevice(moe_buf_host.data());
|
||||
}
|
||||
|
||||
bool rtn = true;
|
||||
if(validate)
|
||||
{
|
||||
ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_expert_ids_ref({max_output_ids / unit_size}, {1});
|
||||
|
||||
int32_t ref_total_tokens_post_pad = 0;
|
||||
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
|
||||
weights_host,
|
||||
sorted_ids_ref,
|
||||
sorted_weights_ref,
|
||||
sorted_expert_ids_ref,
|
||||
ref_total_tokens_post_pad,
|
||||
num_experts,
|
||||
unit_size);
|
||||
rtn &= ck_tile::check_err(
|
||||
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_weights_host,
|
||||
sorted_weights_ref,
|
||||
std::string("OUT Error: Incorrect w!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_expert_ids_host,
|
||||
sorted_expert_ids_ref,
|
||||
std::string("OUT Error: Incorrect eid!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
if(moe_buf_size)
|
||||
{
|
||||
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
|
||||
rtn &= ck_tile::check_err(
|
||||
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
|
||||
}
|
||||
rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
|
||||
}
|
||||
|
||||
printf("valid:%s\n", rtn ? "y" : "n");
|
||||
fflush(stdout);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
|
||||
bool r = true;
|
||||
if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0)
|
||||
{
|
||||
r &= test_moe_sorting<float, ck_tile::index_t>(args);
|
||||
}
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
73
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
Normal file
73
example/ck_tile/13_moe_sorting/moe_sorting_api.cpp
Normal file
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "moe_sorting_api.hpp"
|
||||
|
||||
#define MOE_SORTING_DISPATCH(unroll_num_) \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
if(a.num_experts > 127)
|
||||
{
|
||||
printf("lds size exceed, only support experts <127 \n");
|
||||
return -1;
|
||||
}
|
||||
if(a.moe_buf_bytes % 16)
|
||||
{
|
||||
printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes);
|
||||
return -1;
|
||||
}
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
|
||||
switch(smem_io_unroll_num)
|
||||
{
|
||||
case(1): {
|
||||
MOE_SORTING_DISPATCH(1);
|
||||
}
|
||||
case(2): {
|
||||
MOE_SORTING_DISPATCH(2);
|
||||
}
|
||||
case(3): {
|
||||
MOE_SORTING_DISPATCH(3);
|
||||
}
|
||||
case(5): {
|
||||
MOE_SORTING_DISPATCH(5);
|
||||
}
|
||||
case(6): {
|
||||
MOE_SORTING_DISPATCH(6);
|
||||
}
|
||||
case(7): {
|
||||
MOE_SORTING_DISPATCH(7);
|
||||
}
|
||||
case(8): {
|
||||
MOE_SORTING_DISPATCH(8);
|
||||
}
|
||||
case(9): {
|
||||
MOE_SORTING_DISPATCH(9);
|
||||
}
|
||||
case(10): {
|
||||
MOE_SORTING_DISPATCH(10);
|
||||
}
|
||||
case(11): {
|
||||
MOE_SORTING_DISPATCH(11);
|
||||
}
|
||||
default: {
|
||||
MOE_SORTING_DISPATCH(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
20
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
Normal file
20
example/ck_tile/13_moe_sorting/moe_sorting_api.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/moe_sorting.hpp"
|
||||
|
||||
struct moe_sorting_trait
|
||||
{
|
||||
std::string index_type;
|
||||
std::string weight_type; // currently always float
|
||||
};
|
||||
|
||||
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
19
example/ck_tile/13_moe_sorting/script/smoke_test.sh
Normal file
19
example/ck_tile/13_moe_sorting/script/smoke_test.sh
Normal file
@@ -0,0 +1,19 @@
|
||||
# #!/bin/sh
|
||||
|
||||
EXE=./build/bin/tile_example_moe_sorting
|
||||
|
||||
$EXE -t=80 -e=17 -moe_buf_size=16
|
||||
$EXE -t=111 -e=117 -moe_buf_size=4
|
||||
$EXE -t=1000 -e=55 -moe_buf_size=1024
|
||||
$EXE -t=99 -e=120 -moe_buf_size=10244
|
||||
$EXE -t=175 -e=64 -k=8
|
||||
$EXE -t=65 -e=8 -k=2
|
||||
$EXE -t=1 -e=25
|
||||
$EXE -t=31 -e=19 -k=15
|
||||
$EXE -t=81 -e=37 -k=7
|
||||
$EXE -t=23 -e=1 -k=1
|
||||
$EXE -t=127 -e=99 -k=19
|
||||
$EXE -t=71 -e=11 -k=11
|
||||
$EXE -t=1 -e=1 -k=1
|
||||
$EXE -t=99 -e=2 -k=1
|
||||
$EXE -t=333 -e=99 -k=13
|
||||
@@ -12,3 +12,4 @@ add_subdirectory(09_topk_softmax)
|
||||
add_subdirectory(10_rmsnorm2d)
|
||||
add_subdirectory(11_add_rmsnorm2d_rdquant)
|
||||
add_subdirectory(12_smoothquant)
|
||||
add_subdirectory(13_moe_sorting)
|
||||
|
||||
@@ -63,13 +63,15 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
#define __gfx101__
|
||||
#endif
|
||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \
|
||||
defined(__gfx10_3_generic__)
|
||||
#define __gfx103__
|
||||
#endif
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
|
||||
defined(__gfx1103__) || defined(__gfx11_generic__)
|
||||
#define __gfx11__
|
||||
#endif
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
|
||||
#define __gfx12__
|
||||
#endif
|
||||
|
||||
|
||||
@@ -381,10 +381,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
{
|
||||
tildes = {i_ztilde, i_ytilde, i_xtilde};
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
|
||||
}
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
|
||||
@@ -750,6 +746,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
// check number of dimension, only implemented for 2D and 3D now
|
||||
if(NDimSpatial != 2 && NDimSpatial != 3)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -93,12 +93,12 @@ __global__ void
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
const long_index_t a_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
|
||||
@@ -60,12 +60,12 @@ __global__ void
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
const long_index_t a_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -117,12 +117,12 @@ __global__ void
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
const long_index_t a_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
|
||||
// Pass two lds pointer is the key to tell compiler that ds_read/write
|
||||
// operate on different lds chunk at same time without order dependecy
|
||||
|
||||
@@ -98,12 +98,12 @@ __global__ void
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t c_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
const long_index_t a_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t c_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
|
||||
@@ -60,12 +60,12 @@ __global__ void
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
const long_index_t a_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
@@ -155,12 +155,12 @@ __global__ void
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
const long_index_t a_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
|
||||
@@ -121,10 +121,10 @@ struct GridwiseTensorRearrange
|
||||
__builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
// Global Memory
|
||||
const index_t a_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const index_t c_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
|
||||
const index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
|
||||
|
||||
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize());
|
||||
|
||||
@@ -9,7 +9,8 @@
|
||||
// TODO: Add arch limitation
|
||||
namespace ck {
|
||||
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
|
||||
defined(__gfx1103__) || defined(__gfx11_generic__)
|
||||
#define __gfx11__
|
||||
#endif
|
||||
/********************************WAVE32 MODE***********************************************/
|
||||
@@ -260,7 +261,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
|
||||
// gfx12
|
||||
/********************************WAVE32 MODE***********************************************/
|
||||
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
|
||||
#define __gfx12__
|
||||
#endif
|
||||
|
||||
|
||||
@@ -11,13 +11,15 @@
|
||||
#define __gfx94__
|
||||
#endif
|
||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \
|
||||
defined(__gfx10_3_generic__)
|
||||
#define __gfx103__
|
||||
#endif
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
|
||||
defined(__gfx1103__) || defined(__gfx11_generic__)
|
||||
#define __gfx11__
|
||||
#endif
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__)
|
||||
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
|
||||
#define __gfx12__
|
||||
#endif
|
||||
|
||||
|
||||
@@ -170,7 +170,7 @@ CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
|
||||
}
|
||||
else
|
||||
{
|
||||
// NOT implemented
|
||||
static_assert(false, "The shuffle should always happen!");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_im2col.hpp"
|
||||
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
|
||||
#include "ck_tile/host/reference/reference_permute.hpp"
|
||||
#include "ck_tile/host/reference/reference_reduce.hpp"
|
||||
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
|
||||
|
||||
78
include/ck_tile/host/reference/reference_moe_sorting.hpp
Normal file
78
include/ck_tile/host/reference/reference_moe_sorting.hpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename WeightType, typename IndexType = index_t>
|
||||
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
const HostTensor<WeightType>& weights,
|
||||
HostTensor<IndexType>& p_sorted_token_ids,
|
||||
HostTensor<WeightType>& sorted_weight,
|
||||
HostTensor<IndexType>& sorted_expert_ids,
|
||||
index_t& unit_cnt,
|
||||
const index_t experts,
|
||||
const index_t unit_size)
|
||||
{
|
||||
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
|
||||
const index_t topk = topk_ids.mDesc.get_lengths()[1];
|
||||
std::vector<std::vector<IndexType>> expert_tokens(experts,
|
||||
std::vector<IndexType>(unit_size, num_token));
|
||||
std::vector<std::vector<WeightType>> expert_token_weights(
|
||||
experts, std::vector<WeightType>(unit_size, 0));
|
||||
std::vector<IndexType> expert_slices(experts, 1);
|
||||
std::vector<IndexType> expert_slice_idxs(experts, 0);
|
||||
|
||||
for(index_t t = 0; t < num_token; t++)
|
||||
{
|
||||
for(index_t k = 0; k < topk; k++)
|
||||
{
|
||||
IndexType e = topk_ids(t, k);
|
||||
WeightType w = weights(t, k);
|
||||
index_t idx = expert_slice_idxs[e];
|
||||
if(idx > expert_slices[e] * unit_size - 1)
|
||||
{
|
||||
expert_slices[e]++;
|
||||
index_t new_size = expert_slices[e] * unit_size;
|
||||
expert_tokens[e].resize(new_size);
|
||||
expert_token_weights[e].resize(new_size);
|
||||
for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
|
||||
{
|
||||
expert_tokens[e][i] = num_token;
|
||||
expert_token_weights[e][i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
expert_tokens[e][idx] = t;
|
||||
expert_token_weights[e][idx] = w;
|
||||
expert_slice_idxs[e]++;
|
||||
}
|
||||
}
|
||||
|
||||
IndexType* out_tokens = p_sorted_token_ids.data();
|
||||
WeightType* out_weights = sorted_weight.data();
|
||||
IndexType* out_expert_id = sorted_expert_ids.data();
|
||||
for(index_t e = 0; e < experts; e++)
|
||||
{
|
||||
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
|
||||
out_tokens += expert_slices[e] * unit_size;
|
||||
memcpy(out_weights,
|
||||
expert_token_weights[e].data(),
|
||||
sizeof(WeightType) * expert_slices[e] * unit_size);
|
||||
out_weights += expert_slices[e] * unit_size;
|
||||
|
||||
for(index_t s = 0; s < expert_slices[e]; s++)
|
||||
{
|
||||
out_expert_id[s] = e;
|
||||
unit_cnt++;
|
||||
}
|
||||
out_expert_id += expert_slices[e];
|
||||
}
|
||||
unit_cnt *= unit_size;
|
||||
return;
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -38,9 +38,7 @@ namespace ck_tile {
|
||||
template <typename BlockTile_, // block size, seq<M, N>
|
||||
typename WarpPerBlock_, // num warps along seq<M, N>
|
||||
typename WarpTile_, // warp size, seq<M, N>
|
||||
typename Vector_, // contiguous pixels(vector size) along seq<M, N>
|
||||
index_t BlockSize_ =
|
||||
warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
|
||||
typename Vector_> // contiguous pixels(vector size) along seq<M, N>)>
|
||||
struct Generic2dBlockShape
|
||||
{
|
||||
// block size
|
||||
@@ -68,10 +66,12 @@ struct Generic2dBlockShape
|
||||
static_assert(Warp_M % Vector_M == 0);
|
||||
static_assert(Warp_N % Vector_N == 0);
|
||||
// num of threads along seq<M, N>, within each warp
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
static constexpr index_t ThreadPerBlock_M = Block_M / Repeat_M / Vector_M;
|
||||
static constexpr index_t ThreadPerBlock_N = Block_N / Repeat_N / Vector_N;
|
||||
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -230,7 +230,15 @@ struct PageBlockNavigator
|
||||
CK_TILE_HOST_DEVICE
|
||||
DataType* get_block_ptr(index_t block_index) const
|
||||
{
|
||||
return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset;
|
||||
if(block_index < num_blocks)
|
||||
{
|
||||
return physical_blocks + physical_block_indices[block_index] * block_stride +
|
||||
fixed_offset;
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const
|
||||
|
||||
@@ -863,6 +863,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
static_assert(N0 != 0);
|
||||
|
||||
|
||||
232
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
Normal file
232
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
Normal file
@@ -0,0 +1,232 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct MoeSortingHostArgs
|
||||
{
|
||||
const void* p_topk_ids;
|
||||
const void* p_weights;
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_total_tokens_post_pad;
|
||||
void* p_moe_buf;
|
||||
index_t tokens;
|
||||
index_t unit_size;
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t moe_buf_bytes;
|
||||
};
|
||||
|
||||
template <typename Problem_>
|
||||
struct MoeSortingKernel
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
using IndexType = typename Problem::IndexType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
|
||||
typedef MoeSortingHostArgs MoeSortingKargs;
|
||||
|
||||
using Hargs = MoeSortingHostArgs;
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_topk_ids;
|
||||
const void* p_weights;
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
void* p_total_tokens_post_pad;
|
||||
void* p_moe_buf;
|
||||
index_t tokens;
|
||||
index_t num_experts;
|
||||
index_t moe_buf_bytes;
|
||||
|
||||
index_t tokens_per_thread;
|
||||
mdiv unit_size_mdiv;
|
||||
mdiv topk_mdiv;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
// TODO: assume num-experts not too much
|
||||
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BlockSize(h).x * 16));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
|
||||
{
|
||||
return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
|
||||
}
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
|
||||
{
|
||||
const auto blocks = BlockSize(h);
|
||||
return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_weights = h.p_weights;
|
||||
k.p_sorted_token_ids = h.p_sorted_token_ids;
|
||||
k.p_sorted_weights = h.p_sorted_weights;
|
||||
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
|
||||
k.p_moe_buf = h.p_moe_buf;
|
||||
k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
|
||||
k.tokens = h.tokens;
|
||||
k.num_experts = h.num_experts;
|
||||
k.moe_buf_bytes = h.moe_buf_bytes;
|
||||
|
||||
const auto blocks = BlockSize(h);
|
||||
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
|
||||
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
|
||||
{
|
||||
return row * total_col + col;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes) const
|
||||
{
|
||||
const index_t offset = (blockIdx.x - 1) * blockDim.x + threadIdx.x;
|
||||
if(offset < buf_bytes / 16)
|
||||
{
|
||||
buf[offset] = uint8x16_t{0};
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
|
||||
const WeightType* __restrict__ weights,
|
||||
index_t* p_sorted_token_ids,
|
||||
WeightType* p_sorted_weights,
|
||||
index_t* p_sorted_expert_ids,
|
||||
index_t* p_total_tokens_post_pad,
|
||||
const index_t num_experts,
|
||||
const index_t tokens_per_thread,
|
||||
const index_t numel,
|
||||
const mdiv unit_size_mdiv,
|
||||
const mdiv topk_mdiv,
|
||||
void* smem) const
|
||||
{
|
||||
const index_t tid = static_cast<index_t>(threadIdx.x);
|
||||
const index_t start_idx = tid * tokens_per_thread;
|
||||
|
||||
index_t* shared_mem = reinterpret_cast<index_t*>(smem);
|
||||
|
||||
index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
|
||||
index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1)
|
||||
for(int i = 0; i < num_experts; ++i)
|
||||
{
|
||||
tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0;
|
||||
}
|
||||
#pragma unroll Problem_::InternalLoadUnroll
|
||||
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
|
||||
{
|
||||
++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if(tid < num_experts)
|
||||
{
|
||||
tokens_cnts[calc_index(num_experts, 0, tid)] = 0;
|
||||
for(int i = 1; i <= static_cast<index_t>(blockDim.x); ++i)
|
||||
{
|
||||
tokens_cnts[calc_index(num_experts, i, tid)] +=
|
||||
tokens_cnts[calc_index(num_experts, i - 1, tid)];
|
||||
}
|
||||
}
|
||||
|
||||
// __syncthreads();
|
||||
if(tid == 0)
|
||||
{
|
||||
cumsum[0] = 0;
|
||||
for(int i = 1; i <= num_experts; ++i)
|
||||
{
|
||||
auto current_units = [&]() {
|
||||
index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] +
|
||||
unit_size_mdiv.divisor - 1;
|
||||
index_t y_ = unit_size_mdiv.div(x_);
|
||||
return max(y_, 1) * unit_size_mdiv.divisor;
|
||||
}();
|
||||
cumsum[i] = cumsum[i - 1] + current_units;
|
||||
}
|
||||
*p_total_tokens_post_pad = cumsum[num_experts];
|
||||
}
|
||||
__syncthreads();
|
||||
if(tid < num_experts)
|
||||
{
|
||||
for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor)
|
||||
{
|
||||
p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll Problem_::InternalLoadUnroll
|
||||
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
|
||||
{
|
||||
index_t expert_id = topk_id[i];
|
||||
index_t rank_post_pad =
|
||||
tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
|
||||
p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
|
||||
p_sorted_weights[rank_post_pad] = weights[i];
|
||||
++tokens_cnts[calc_index(num_experts, tid, expert_id)];
|
||||
}
|
||||
|
||||
const index_t prefill_token = topk_mdiv.div(numel);
|
||||
if(tid < num_experts)
|
||||
{
|
||||
index_t expert_offset =
|
||||
cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
|
||||
while(expert_offset < cumsum[tid + 1])
|
||||
{
|
||||
p_sorted_token_ids[expert_offset] = prefill_token;
|
||||
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
|
||||
expert_offset++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
if(blockIdx.x > 0)
|
||||
{
|
||||
if(kargs.p_moe_buf)
|
||||
{
|
||||
moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
|
||||
kargs.moe_buf_bytes);
|
||||
}
|
||||
return;
|
||||
}
|
||||
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
|
||||
extern __shared__ char smem[];
|
||||
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
|
||||
static_cast<const WeightType*>(kargs.p_weights),
|
||||
static_cast<IndexType*>(kargs.p_sorted_token_ids),
|
||||
static_cast<WeightType*>(kargs.p_sorted_weights),
|
||||
static_cast<IndexType*>(kargs.p_sorted_expert_ids),
|
||||
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
|
||||
kargs.num_experts,
|
||||
kargs.tokens_per_thread,
|
||||
numel,
|
||||
kargs.unit_size_mdiv,
|
||||
kargs.topk_mdiv,
|
||||
smem);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
|
||||
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// template <typename Problem_, typename Policy_ = MoeSortingPolicy>
|
||||
// struct MoeSortingPipeline
|
||||
// {
|
||||
// // TODO: this kernel only support warp per row
|
||||
// using Problem = remove_cvref_t<Problem_>;
|
||||
// using Policy = remove_cvref_t<Policy_>;
|
||||
// using WeightType = typename Problem::WeightType;
|
||||
|
||||
// template <typename TopkIdWindow, typename WeightWindow>
|
||||
// CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
|
||||
// const WeightWindow& weight_window,
|
||||
// index_t* p_sorted_token_ids,
|
||||
// WeightType* p_sorted_weights,
|
||||
// index_t* p_sorted_expert_ids,
|
||||
// index_t* p_total_tokens_post_pad,
|
||||
// const index_t num_experts,
|
||||
// const index_t unit_size,
|
||||
// const size_t numel,
|
||||
// const index_t topk)
|
||||
// {
|
||||
// }
|
||||
// };
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/softmax.hpp"
|
||||
#include "ck_tile/ops/topk.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct MoeSortingPolicy
|
||||
{
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename IndexType_, typename WeightType_, index_t InternalLoadUnroll_>
|
||||
struct MoeSortingProblem
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using WeightType = remove_cvref_t<WeightType_>;
|
||||
using IndexType = remove_cvref_t<IndexType_>;
|
||||
|
||||
static constexpr index_t WarpSize = get_warp_size();
|
||||
static constexpr index_t WarpsPerBlock = 1;
|
||||
static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -115,12 +115,22 @@ struct GemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto a_pad_view = pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
// somehow clang-format is splitting below line into multiple.
|
||||
// clang-format off
|
||||
sequence<false, GemmPipeline::kPadA>{});
|
||||
auto a_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
// clang-format on
|
||||
|
||||
auto a_block_window = make_tile_window(
|
||||
@@ -128,12 +138,22 @@ struct GemmKernel
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
{i_m, 0});
|
||||
|
||||
auto b_pad_view = pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
// clang-format off
|
||||
sequence<false, GemmPipeline::kPadB>{});
|
||||
// clang-format on
|
||||
auto b_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
sequence<GemmPipeline::kPadN, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto b_block_window = make_tile_window(
|
||||
b_pad_view,
|
||||
@@ -171,18 +191,28 @@ struct GemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto c_pad_view = pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
// clang-format off
|
||||
sequence<false, GemmPipeline::kPadC>{});
|
||||
// clang-format on
|
||||
auto c_block_window = make_tile_window(
|
||||
auto c_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
auto CBlockWindow_pad = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile);
|
||||
EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr index_t VectorSizeB = Problem::VectorSizeB;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
|
||||
static constexpr bool kPadA = Problem::kPadA;
|
||||
static constexpr bool kPadB = Problem::kPadB;
|
||||
static constexpr bool kPadC = Problem::kPadC;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
// Where is the right place for HasHotLoop and TailNum ???
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
|
||||
@@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t VectorSizeB = Problem::VectorSizeB;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
|
||||
static constexpr bool kPadA = Problem::kPadA;
|
||||
static constexpr bool kPadB = Problem::kPadB;
|
||||
static constexpr bool kPadC = Problem::kPadC;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
|
||||
{
|
||||
@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window =
|
||||
make_tile_window(a_lds_block,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
a_copy_dram_window.get_tile_distribution());
|
||||
auto a_copy_lds_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window =
|
||||
@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// B LDS tile window for store
|
||||
auto b_copy_lds_window =
|
||||
make_tile_window(b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
auto b_copy_lds_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegBlockDescriptor<Problem>());
|
||||
shuffle_tile(a_shuffle_tmp, a_block_tile);
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
|
||||
}
|
||||
|
||||
// LDS write 0
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(b_shuffle_tmp, b_block_tile);
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
|
||||
}
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
|
||||
// LDS write i + 1
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
|
||||
store_tile(b_copy_lds_window,
|
||||
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
}
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ namespace ck_tile {
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
|
||||
#if 0
|
||||
// 2d
|
||||
template <typename Problem>
|
||||
@@ -116,6 +117,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
|
||||
return smem_size;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
return Problem::VectorLoadSize / sizeof(ADataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
return Problem::VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
#elif 1
|
||||
// fake XOR
|
||||
template <typename Problem>
|
||||
@@ -192,80 +207,269 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
#if 1 // coalesce reading for each blocks
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M0 = MPerBlock / M1;
|
||||
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
constexpr index_t K3 = total_pixels / M1;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
static_assert(KPack % K3 == 0);
|
||||
constexpr index_t K2 = KPack / K3;
|
||||
if constexpr(get_warp_size() % (K2 * M0))
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * M0);
|
||||
constexpr index_t K0 = BlockSize / get_warp_size();
|
||||
static_assert(KPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * M0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
|
||||
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
if constexpr(get_warp_size() % (M2 * K0) == 0)
|
||||
{
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
#else // coalesce reading for each warps
|
||||
constexpr index_t M0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M1 = kMPerBlock / (M2 * M0);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
#endif
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t M0 = BlockSize / get_warp_size();
|
||||
constexpr index_t M1 = MPerBlock / (M2 * M0);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
|
||||
constexpr index_t N0 = NPerBlock / N1;
|
||||
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % N1 == 0);
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
static_assert(KPack % K3 == 0);
|
||||
constexpr index_t K2 = KPack / K3;
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = BlockSize / get_warp_size();
|
||||
static_assert(KPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * N0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
|
||||
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
if constexpr(get_warp_size() % (N2 * K0) == 0)
|
||||
{
|
||||
constexpr index_t N1 = BlockSize / get_warp_size();
|
||||
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
// coalesce reading for each warps
|
||||
else
|
||||
{
|
||||
constexpr index_t N0 = BlockSize / get_warp_size();
|
||||
constexpr index_t N1 = NPerBlock / (N2 * N0);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
#if 1 // coalesce reading for each blocks
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % N1 == 0);
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemPackB<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
if constexpr(warp_size % (K2 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = warp_size / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / warp_size;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
#else // coalesce reading for each warps
|
||||
constexpr index_t N0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N1 = kNPerBlock / (N2 * N0);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * N0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
|
||||
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
}
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
#endif
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M0 = kMPerBlock / M1;
|
||||
constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
constexpr index_t K3 = total_pixels / M1;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
if constexpr(warp_size % (K2 * M0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = warp_size / (K2 * M0);
|
||||
constexpr index_t K0 = kBlockSize / warp_size;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * M0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
|
||||
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -3,40 +3,133 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
static constexpr int _VectorSize = 16;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename TileGemmTraits_>
|
||||
struct GemmPipelineProblem
|
||||
struct GemmPipelineProblemBase
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using GemmTraits = remove_cvref_t<TileGemmTraits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
using GemmTraits = remove_cvref_t<TileGemmTraits_>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename GemmTraits::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmTraits::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmTraits::CLayout>;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kPadA = GemmTraits::kPadA;
|
||||
static constexpr bool kPadB = GemmTraits::kPadB;
|
||||
static constexpr bool kPadC = GemmTraits::kPadC;
|
||||
static constexpr index_t VectorLoadSize = GemmTraits::_VectorSize;
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr index_t VectorSizeA = kPadA ? 1 : _VectorSize / sizeof(ADataType);
|
||||
static constexpr index_t VectorSizeB = kPadB ? 1 : _VectorSize / sizeof(BDataType);
|
||||
static constexpr index_t VectorSizeC = kPadC ? 1 : _VectorSize / sizeof(CDataType);
|
||||
static constexpr bool kPadM = GemmTraits::kPadM;
|
||||
static constexpr bool kPadN = GemmTraits::kPadN;
|
||||
static constexpr bool kPadK = GemmTraits::kPadK;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
|
||||
return pixels_per_thread < VectorLoadSize / sizeof(ADataType)
|
||||
? pixels_per_thread
|
||||
: VectorLoadSize / sizeof(ADataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return VectorLoadSize / sizeof(ADataType);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
|
||||
{
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
|
||||
return pixels_per_thread < VectorLoadSize / sizeof(BDataType)
|
||||
? pixels_per_thread
|
||||
: VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
|
||||
{
|
||||
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
|
||||
constexpr index_t M0 = get_warp_size() / N2;
|
||||
constexpr index_t M1 = BlockGemmShape::kM / M0;
|
||||
|
||||
return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
|
||||
constexpr index_t N0 = get_warp_size() / M2;
|
||||
constexpr index_t N1 = BlockGemmShape::kN / N0;
|
||||
|
||||
return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr index_t VectorSizeA = []() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return kPadK ? 1 : GetAlignmentA();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadM ? 1 : GetAlignmentA();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr index_t VectorSizeB = []() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return kPadN ? 1 : GetAlignmentB();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadK ? 1 : GetAlignmentB();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr index_t VectorSizeC = []() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return kPadN ? 1 : GetAlignmentC();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadM ? 1 : GetAlignmentC();
|
||||
}
|
||||
}();
|
||||
};
|
||||
|
||||
// Alias for GemmPipelineProblem
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename TileGemmTraits_>
|
||||
using GemmPipelineProblem =
|
||||
GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, TileGemmTraits_>;
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
@@ -45,30 +138,15 @@ template <typename ADataType_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full>
|
||||
struct UniversalGemmPipelineProblem
|
||||
struct UniversalGemmPipelineProblem : public GemmPipelineProblemBase<ADataType_,
|
||||
BDataType_,
|
||||
CDataType_,
|
||||
BlockGemmShape_,
|
||||
TileGemmTraits_>
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
using GemmTraits = remove_cvref_t<TileGemmTraits_>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename GemmTraits::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmTraits::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmTraits::CLayout>;
|
||||
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kPadA = GemmTraits::kPadA;
|
||||
static constexpr bool kPadB = GemmTraits::kPadB;
|
||||
static constexpr bool kPadC = GemmTraits::kPadC;
|
||||
|
||||
static constexpr index_t VectorSizeA = kPadA ? _VectorSize / sizeof(ADataType) : 1;
|
||||
static constexpr index_t VectorSizeB = kPadB ? _VectorSize / sizeof(BDataType) : 1;
|
||||
static constexpr index_t VectorSizeC = kPadC ? _VectorSize / sizeof(CDataType) : 1;
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -9,12 +9,8 @@
|
||||
namespace ck_tile {
|
||||
|
||||
// UniversalGemm Policy
|
||||
template <typename LayoutA_, typename LayoutB_, typename LayoutC_>
|
||||
struct UniversalGemmPipelineAgBgCrPolicy
|
||||
{
|
||||
using LayoutA = remove_cvref_t<LayoutA_>;
|
||||
using LayoutB = remove_cvref_t<LayoutB_>;
|
||||
using LayoutC = remove_cvref_t<LayoutC_>;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
@@ -22,286 +18,136 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
|
||||
static constexpr bool TransposeC = true;
|
||||
|
||||
template <typename Problem, typename DataType, index_t MNPerBlock>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
|
||||
|
||||
if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (16 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(elements_per_thread % (8 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (8 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(elements_per_thread % (4 / sizeof(DataType)) == 0 &&
|
||||
sizeof(DataType) >= 4)
|
||||
{
|
||||
return (4 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(elements_per_thread % (2 / sizeof(DataType)) == 0 &&
|
||||
sizeof(DataType) >= 2)
|
||||
{
|
||||
return (2 / sizeof(DataType));
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
Problem::BlockGemmShape::WarpTile::at(I0),
|
||||
Problem::BlockGemmShape::WarpTile::at(I1),
|
||||
Problem::BlockGemmShape::WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t K1 = WarpGemm::kK;
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t KPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
|
||||
|
||||
if constexpr(std::is_same<tensor_layout::gemm::RowMajor, LayoutA>::value)
|
||||
{
|
||||
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
|
||||
? 1
|
||||
: 32 * 4 / KPerBlock / sizeof(ADataType);
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K0 * number<MLdsLayer>{}, number<MPerBlock / MLdsLayer>{}, K1),
|
||||
make_tuple(K1, number<KPerBlock * MLdsLayer>{}, I1));
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
constexpr auto MLdsLayer =
|
||||
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
|
||||
number<K0 * MLdsLayer>{})),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
|
||||
number<MPerBlock / MLdsLayer>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc_ak0_kMLdsLayer_m_ak1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, number<MLdsLayer>{})),
|
||||
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
|
||||
number<KPerBlock / KPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
|
||||
a_lds_block_desc_ak0_kMLdsLayer_m_ak1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{}))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MLdsLayer>{})),
|
||||
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
return a_lds_block_desc_m_k;
|
||||
}
|
||||
else // ColumnMajor A
|
||||
{
|
||||
// kfold and mpair dimension is not always required.
|
||||
// more dimension in merge_transform increase the difficulty of generating immarg offset
|
||||
// for compiler.
|
||||
constexpr auto M0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I0);
|
||||
constexpr auto M1 = MPerBlock / M0;
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr auto KThreadWrite = Problem::kBlockSize / M0;
|
||||
constexpr auto K0PerThreadWrite = K0 / KThreadWrite;
|
||||
constexpr auto KThreadRead = 64 / WarpGemm::kM;
|
||||
constexpr auto K0PerThreadRead = K0 / KThreadRead;
|
||||
|
||||
constexpr auto kfold =
|
||||
(K1 * M0 * sizeof(ADataType) > 128) ? 1 : 128 / (K1 * M0 * sizeof(ADataType));
|
||||
constexpr auto KThreadReadPerm =
|
||||
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
|
||||
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
|
||||
: KThreadRead;
|
||||
|
||||
// 1<=mpair<=kN0
|
||||
constexpr auto mpair = (K1 * WarpGemm::kM * sizeof(ADataType) > 128)
|
||||
? 1
|
||||
: ((128 / (K1 * WarpGemm::kM * sizeof(ADataType))) > M0
|
||||
? M0
|
||||
: 128 / (K1 * WarpGemm::kM * sizeof(ADataType)));
|
||||
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
number<KThreadReadPerm * M1>{},
|
||||
number<kfold * M0 / mpair>{},
|
||||
number<mpair>{},
|
||||
K1));
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_tuple(number<KThreadReadPerm * M1>{}, number<kfold * M0 / mpair>{})),
|
||||
make_pass_through_transform(number<mpair>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<M1>{})),
|
||||
make_unmerge_transform(make_tuple(number<kfold>{}, number<M0 / mpair>{})),
|
||||
make_pass_through_transform(number<mpair>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<0, 3>{},
|
||||
sequence<4, 5>{},
|
||||
sequence<6>{},
|
||||
sequence<7>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
|
||||
a_lds_block_desc_unmerged,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KThreadReadPerm>{},
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<kfold>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
K1)),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<M0 / mpair>{}, number<mpair>{}, number<M1>{}))),
|
||||
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return a_lds_block_desc_m_k;
|
||||
}
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
Problem::BlockGemmShape::WarpTile::at(I0),
|
||||
Problem::BlockGemmShape::WarpTile::at(I1),
|
||||
Problem::BlockGemmShape::WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
|
||||
|
||||
constexpr index_t K1 = WarpGemm::kK;
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
|
||||
|
||||
if constexpr(std::is_same<tensor_layout::gemm::ColumnMajor, LayoutB>::value)
|
||||
{
|
||||
// NLdsLayer * K0 as logical Bank
|
||||
constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1
|
||||
? 1
|
||||
: 32 * 4 / KPerBlock / sizeof(BDataType);
|
||||
;
|
||||
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K0 * number<NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, K1),
|
||||
make_tuple(K1, number<KPerBlock * NLdsLayer>{}, I1));
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack * NLdsLayer>{},
|
||||
number<NPerBlock / NLdsLayer>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
|
||||
number<K0 * NLdsLayer>{})),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
|
||||
number<KPerBlock / KPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_bk0_kNLdsLayer_n_bk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, number<NLdsLayer>{})),
|
||||
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<NLdsLayer>{})),
|
||||
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
|
||||
b_lds_block_desc_bk0_kNLdsLayer_n_bk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{}))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return b_lds_block_desc_n_k;
|
||||
}
|
||||
else // RowMajor B
|
||||
{
|
||||
constexpr auto N0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I1);
|
||||
constexpr auto N1 = NPerBlock / N0;
|
||||
|
||||
constexpr auto KThreadWrite = Problem::kBlockSize / N0;
|
||||
constexpr auto K0PerThreadWrite = K0 / KThreadWrite;
|
||||
constexpr auto KThreadRead = 64 / WarpGemm::kN;
|
||||
constexpr auto K0PerThreadRead = K0 / KThreadRead;
|
||||
|
||||
constexpr auto kfold =
|
||||
(K1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (K1 * N0 * sizeof(BDataType));
|
||||
constexpr auto KThreadReadPerm =
|
||||
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
|
||||
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
|
||||
: KThreadRead;
|
||||
|
||||
// 1<=npair<=kN0
|
||||
constexpr auto npair = (K1 * WarpGemm::kN * sizeof(BDataType) > 128)
|
||||
? 1
|
||||
: ((128 / (K1 * WarpGemm::kN * sizeof(BDataType))) > N0
|
||||
? N0
|
||||
: 128 / (K1 * WarpGemm::kN * sizeof(BDataType)));
|
||||
|
||||
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
number<KThreadReadPerm * N1>{},
|
||||
number<kfold * N0 / npair>{},
|
||||
number<npair>{},
|
||||
K1));
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_tuple(number<KThreadReadPerm * N1>{}, number<kfold * N0 / npair>{})),
|
||||
make_pass_through_transform(number<npair>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
|
||||
make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
|
||||
make_pass_through_transform(number<npair>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<0, 3>{},
|
||||
sequence<4, 5>{},
|
||||
sequence<6>{},
|
||||
sequence<7>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
|
||||
b_lds_block_desc_unmerged,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KThreadReadPerm>{},
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<kfold>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
K1)),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
|
||||
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return b_lds_block_desc_n_k;
|
||||
}
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -334,69 +180,268 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
Problem::BlockGemmShape::WarpTile::at(I0),
|
||||
Problem::BlockGemmShape::WarpTile::at(I1),
|
||||
Problem::BlockGemmShape::WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = WarpGemm::kK;
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M0 = MPerBlock / M1;
|
||||
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
constexpr index_t K3 = total_pixels / M1;
|
||||
constexpr index_t KPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
|
||||
static_assert(KPack % K3 == 0);
|
||||
constexpr index_t K2 = KPack / K3;
|
||||
if constexpr(get_warp_size() % (K2 * M0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * M0);
|
||||
constexpr index_t K0 = BlockSize / get_warp_size();
|
||||
static_assert(KPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * M0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
|
||||
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
if constexpr(get_warp_size() % (M2 * K0) == 0)
|
||||
{
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t M0 = BlockSize / get_warp_size();
|
||||
constexpr index_t M1 = MPerBlock / (M2 * M0);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
Problem::BlockGemmShape::WarpTile::at(I0),
|
||||
Problem::BlockGemmShape::WarpTile::at(I1),
|
||||
Problem::BlockGemmShape::WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = WarpGemm::kK;
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
|
||||
constexpr index_t N0 = NPerBlock / N1;
|
||||
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % N1 == 0);
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
|
||||
static_assert(KPack % K3 == 0);
|
||||
constexpr index_t K2 = KPack / K3;
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = BlockSize / get_warp_size();
|
||||
static_assert(KPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * N0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
|
||||
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
constexpr index_t N1 = BlockSize / get_warp_size();
|
||||
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
if constexpr(get_warp_size() % (N2 * K0) == 0)
|
||||
{
|
||||
constexpr index_t N1 = BlockSize / get_warp_size();
|
||||
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
// coalesce reading for each warps
|
||||
else
|
||||
{
|
||||
constexpr index_t N0 = BlockSize / get_warp_size();
|
||||
constexpr index_t N1 = NPerBlock / (N2 * N0);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M0 = MPerBlock / M1;
|
||||
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
constexpr index_t K3 = total_pixels / M1;
|
||||
constexpr index_t kKPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
if constexpr(warp_size % (K2 * M0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = warp_size / (K2 * M0);
|
||||
constexpr index_t K0 = BlockSize / warp_size;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * M0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
|
||||
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
|
||||
constexpr index_t N0 = NPerBlock / N1;
|
||||
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % N1 == 0);
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
if constexpr(warp_size % (K2 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = warp_size / (K2 * N0);
|
||||
constexpr index_t K0 = BlockSize / warp_size;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * N0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
|
||||
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -3,19 +3,23 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <bool kPadA_,
|
||||
bool kPadB_,
|
||||
bool kPadC_,
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_>
|
||||
struct TileGemmTraits
|
||||
{
|
||||
static constexpr bool kPadA = kPadA_;
|
||||
static constexpr bool kPadB = kPadB_;
|
||||
static constexpr bool kPadC = kPadC_;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
|
||||
static constexpr int _VectorSize = 16;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
|
||||
@@ -28,7 +28,10 @@ struct Layernorm2dFwdHostArgs
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t stride; // row_stride
|
||||
index_t x_stride; // x row_stride
|
||||
index_t xr_stride; // x residule row stride
|
||||
index_t y_stride; // y row stride
|
||||
index_t yr_stride; // y residule row stride
|
||||
};
|
||||
|
||||
// TODO: Extract some type to wrapper class
|
||||
@@ -93,7 +96,10 @@ struct Layernorm2dFwd
|
||||
|
||||
index_t m;
|
||||
index_t n;
|
||||
index_t stride; // row_stride
|
||||
index_t x_stride; // x row_stride
|
||||
index_t xr_stride; // x residule row stride
|
||||
index_t y_stride; // y row stride
|
||||
index_t yr_stride; // y residule row stride
|
||||
};
|
||||
using Hargs = Layernorm2dFwdHostArgs;
|
||||
|
||||
@@ -112,7 +118,10 @@ struct Layernorm2dFwd
|
||||
hargs.epsilon,
|
||||
hargs.m,
|
||||
hargs.n,
|
||||
hargs.stride};
|
||||
hargs.x_stride,
|
||||
hargs.xr_stride,
|
||||
hargs.y_stride,
|
||||
hargs.yr_stride};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
|
||||
@@ -182,7 +191,7 @@ struct Layernorm2dFwd
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XDataType*>(kargs.p_x),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
make_tuple(kargs.x_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -201,7 +210,7 @@ struct Layernorm2dFwd
|
||||
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XResidualDataType*>(kargs.p_x_residual),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
make_tuple(kargs.xr_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -250,7 +259,7 @@ struct Layernorm2dFwd
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YDataType*>(kargs.p_y),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
make_tuple(kargs.y_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
@@ -266,7 +275,7 @@ struct Layernorm2dFwd
|
||||
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YResidualDataType*>(kargs.p_y_residual),
|
||||
make_tuple(kargs.m, kargs.n),
|
||||
make_tuple(kargs.stride, 1),
|
||||
make_tuple(kargs.yr_stride, 1),
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
|
||||
@@ -47,7 +47,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
|
||||
return BlockWelford<P_>{};
|
||||
}
|
||||
@@ -57,7 +58,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
|
||||
return BlockWelfordSync<P_>{};
|
||||
}
|
||||
@@ -67,7 +69,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
|
||||
return BlockWelfordCrossWarpSync<P_>{};
|
||||
}
|
||||
@@ -79,7 +82,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
{
|
||||
using P_ = BlockWelfordProblem<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::Traits::kFastFDiv>;
|
||||
|
||||
using block_welford = BlockWelford<P_>;
|
||||
using x_block_tile =
|
||||
|
||||
@@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
@@ -120,12 +121,20 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
auto [mean, var] = block_welford(acc, cur_count, max_count);
|
||||
block_welford_sync(mean, var, cur_count);
|
||||
block_welford_cross_warp_sync(mean, var, cur_count, smem);
|
||||
block_tile_welford_post_scale_var(var, cur_count);
|
||||
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
|
||||
|
||||
// compute inv-std
|
||||
auto inv_std = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
|
||||
if(kFastFDiv && std::is_same_v<ComputeDataType, float>)
|
||||
{
|
||||
return type_convert<ComputeDataType>(1.0f) *
|
||||
__builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
|
||||
}
|
||||
else
|
||||
{
|
||||
return type_convert<ComputeDataType>(1.0f) / sqrt(v_ + epsilon);
|
||||
}
|
||||
},
|
||||
var);
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
|
||||
static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
|
||||
static constexpr bool kPadN = Problem::Traits::kPadN;
|
||||
static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
|
||||
static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
|
||||
static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
|
||||
|
||||
@@ -137,15 +138,22 @@ struct Layernorm2dFwdPipelineTwoPass
|
||||
|
||||
block_welford_sync(mean, var, cur_count);
|
||||
block_welford_cross_warp_sync(mean, var, cur_count, smem);
|
||||
block_tile_welford_post_scale_var(var, cur_count);
|
||||
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
|
||||
|
||||
// compute inv-std
|
||||
auto inv_std = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ + epsilon));
|
||||
if(kFastFDiv && std::is_same_v<ComputeDataType, float>)
|
||||
{
|
||||
return type_convert<ComputeDataType>(1.0f) *
|
||||
__builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
|
||||
}
|
||||
else
|
||||
{
|
||||
return type_convert<ComputeDataType>(1.0f) / sqrt(v_ + epsilon);
|
||||
}
|
||||
},
|
||||
var);
|
||||
|
||||
if constexpr(kSaveMean)
|
||||
store_tile(mean_window, cast_tile<MeanDataType>(mean));
|
||||
if constexpr(kSaveInvStd)
|
||||
|
||||
@@ -39,6 +39,7 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
|
||||
|
||||
template <bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kFastFDiv_,
|
||||
bool kTwoPass_,
|
||||
Layernorm2dFusedAddEnum kFusedAdd_,
|
||||
Layernorm2dFusedQuantEnum kFusedQuant_>
|
||||
@@ -46,6 +47,7 @@ struct Layernorm2dFwdTraits
|
||||
{
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
static constexpr bool kFastFDiv = kFastFDiv_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
|
||||
static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
|
||||
|
||||
11
include/ck_tile/ops/moe_sorting.hpp
Normal file
11
include/ck_tile/ops/moe_sorting.hpp
Normal file
@@ -0,0 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
@@ -11,9 +11,10 @@ namespace ck_tile {
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockWelford
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
using ComputeDataType = typename Problem::ComputeDataType;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
using ComputeDataType = typename Problem::ComputeDataType;
|
||||
static constexpr bool kFastFDiv = Problem::kFastFDiv;
|
||||
|
||||
CK_TILE_DEVICE constexpr BlockWelford() {}
|
||||
|
||||
@@ -46,8 +47,11 @@ struct BlockWelford
|
||||
|
||||
auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
|
||||
|
||||
welford_update(
|
||||
mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x, cur_count_);
|
||||
welford_update(mean_tensor(out_dstr_idx),
|
||||
var_tensor(out_dstr_idx),
|
||||
x,
|
||||
cur_count_,
|
||||
constant<kFastFDiv>{});
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -89,7 +93,8 @@ struct BlockWelford
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockWelfordSync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
static constexpr bool kFastFDiv = Problem::kFastFDiv;
|
||||
|
||||
template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
|
||||
CK_TILE_DEVICE void
|
||||
@@ -157,7 +162,8 @@ struct BlockWelfordSync
|
||||
v_local_count,
|
||||
v_remote_mean,
|
||||
v_remote_var,
|
||||
v_remote_count);
|
||||
v_remote_count,
|
||||
constant<kFastFDiv>{});
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -173,8 +179,9 @@ struct BlockWelfordSync
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockWelfordCrossWarpSync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using BlockShape = typename Problem::BlockShape;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using BlockShape = typename Problem::BlockShape;
|
||||
static constexpr bool kFastFDiv = Problem::kFastFDiv;
|
||||
|
||||
template <typename MeanDistributedTensor_>
|
||||
CK_TILE_DEVICE static constexpr index_t GetReduceWarps()
|
||||
@@ -304,7 +311,8 @@ struct BlockWelfordCrossWarpSync
|
||||
v_local_count,
|
||||
v_remote_mean,
|
||||
v_remote_var,
|
||||
v_remote_count);
|
||||
v_remote_count,
|
||||
constant<kFastFDiv>{});
|
||||
});
|
||||
|
||||
mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
|
||||
@@ -351,12 +359,23 @@ CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_
|
||||
}
|
||||
|
||||
// Note: this function must be called after all the computation
|
||||
template <typename VarDistributedTensor_>
|
||||
template <typename VarDistributedTensor_, bool FastFdiv_ = false>
|
||||
CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTensor_& var_tensor,
|
||||
int count)
|
||||
int count,
|
||||
bool_constant<FastFdiv_> = {})
|
||||
{
|
||||
using DataType = typename VarDistributedTensor_::DataType;
|
||||
tile_elementwise_inout([&count](auto& x) { x = x / type_convert<DataType>(count); },
|
||||
var_tensor);
|
||||
tile_elementwise_inout(
|
||||
[&count](auto& x) {
|
||||
if(FastFdiv_ && std::is_same_v<DataType, float>)
|
||||
{
|
||||
x = x * __builtin_amdgcn_rcpf(type_convert<DataType>(count));
|
||||
}
|
||||
else
|
||||
{
|
||||
x = x / type_convert<DataType>(count);
|
||||
}
|
||||
},
|
||||
var_tensor);
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -7,12 +7,13 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_>
|
||||
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_, bool kFastFDiv_>
|
||||
struct BlockWelfordProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
static constexpr bool kFastFDiv = kFastFDiv_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -7,25 +7,46 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count)
|
||||
template <typename T, bool kFastFDiv = false>
|
||||
CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count, bool_constant<kFastFDiv> = {})
|
||||
{
|
||||
// TODO: check nan? maybe no
|
||||
T delta = x - mean;
|
||||
mean += delta / count;
|
||||
if(kFastFDiv && std::is_same_v<T, float>)
|
||||
{
|
||||
mean += delta * __builtin_amdgcn_rcpf(count);
|
||||
}
|
||||
else
|
||||
{
|
||||
mean += delta / count;
|
||||
}
|
||||
T delta2 = x - mean;
|
||||
var += delta * delta2;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE static void
|
||||
welford_merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
|
||||
template <typename T, bool kFastFDiv = false>
|
||||
CK_TILE_DEVICE static void welford_merge(T& mean_a,
|
||||
T& var_a,
|
||||
int& count_a,
|
||||
T mean_b,
|
||||
T var_b,
|
||||
int count_b,
|
||||
bool_constant<kFastFDiv> = {})
|
||||
{
|
||||
int count = count_a + count_b;
|
||||
T count_ = type_convert<T>(count);
|
||||
T count_a_ = type_convert<T>(count_a);
|
||||
T count_b_ = type_convert<T>(count_b);
|
||||
T count_b_over_count = count == 0 ? type_convert<T>(0) : count_b_ / count_;
|
||||
int count = count_a + count_b;
|
||||
T count_ = type_convert<T>(count);
|
||||
T count_a_ = type_convert<T>(count_a);
|
||||
T count_b_ = type_convert<T>(count_b);
|
||||
T count_b_over_count;
|
||||
if(kFastFDiv && std::is_same_v<T, float>)
|
||||
{
|
||||
count_b_over_count =
|
||||
count == 0 ? type_convert<T>(0) : count_b_ * __builtin_amdgcn_rcpf(count_);
|
||||
}
|
||||
else
|
||||
{
|
||||
count_b_over_count = count == 0 ? type_convert<T>(0) : count_b_ / count_;
|
||||
}
|
||||
|
||||
T delta = mean_b - mean_a;
|
||||
mean_a += delta * count_b_over_count;
|
||||
|
||||
@@ -39,7 +39,25 @@ template <ck::index_t NDimSpatial,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler,
|
||||
BlockGemmPipelineVersion PipelineVersion>
|
||||
using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple<
|
||||
using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler,
|
||||
BlockGemmPipelineVersion PipelineVersion>
|
||||
using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
|
||||
@@ -64,7 +82,25 @@ template <ck::index_t NDimSpatial,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler,
|
||||
BlockGemmPipelineVersion PipelineVersion>
|
||||
using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances = std::tuple<
|
||||
using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler,
|
||||
BlockGemmPipelineVersion PipelineVersion>
|
||||
using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
|
||||
@@ -82,6 +118,24 @@ using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances = st
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler,
|
||||
BlockGemmPipelineVersion PipelineVersion>
|
||||
using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// NGCHW requires transpose, we use vector loads and stores params for them
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
@@ -122,6 +176,24 @@ using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_instances
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler,
|
||||
BlockGemmPipelineVersion PipelineVersion>
|
||||
using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| |
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
|
||||
@@ -352,6 +352,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances(
|
||||
@@ -375,6 +377,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instances(
|
||||
@@ -390,6 +394,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
is_same_v<ComputeTypeB, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances(
|
||||
@@ -403,6 +409,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<ComputeTypeA, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeB, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instances(
|
||||
@@ -464,6 +472,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances(
|
||||
@@ -487,6 +497,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instances(
|
||||
@@ -511,6 +523,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
is_same_v<ComputeTypeB, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances(
|
||||
@@ -524,6 +538,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<ComputeTypeA, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeB, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances(
|
||||
|
||||
@@ -113,6 +113,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_in
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
@@ -136,6 +148,19 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_p
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
NGKHW,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
@@ -173,6 +198,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
@@ -196,6 +233,19 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
@@ -298,6 +348,18 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
@@ -321,6 +383,19 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf1
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
NGKDHW,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
@@ -358,6 +433,18 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
@@ -381,6 +468,19 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
NGKDHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
|
||||
@@ -24,7 +24,7 @@ namespace ck {
|
||||
namespace utils {
|
||||
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
double get_relative_threshold(const int numberOfAccumulations = 1)
|
||||
double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
{
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
@@ -79,13 +79,13 @@ double get_relative_threshold(const int numberOfAccumulations = 1)
|
||||
}
|
||||
else
|
||||
{
|
||||
acc_error = std::pow(2, -NumericUtils<AccDataType>::mant) * 0.5 * numberOfAccumulations;
|
||||
acc_error = std::pow(2, -NumericUtils<AccDataType>::mant) * 0.5 * number_of_accumulations;
|
||||
}
|
||||
return std::max(acc_error, midway_error);
|
||||
}
|
||||
|
||||
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
|
||||
double get_absolute_threshold(const double max_possible_num, const int numberOfAccumulations = 1)
|
||||
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
|
||||
{
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
@@ -142,7 +142,7 @@ double get_absolute_threshold(const double max_possible_num, const int numberOfA
|
||||
else
|
||||
{
|
||||
acc_error =
|
||||
std::pow(2, expo - NumericUtils<AccDataType>::mant) * 0.5 * numberOfAccumulations;
|
||||
std::pow(2, expo - NumericUtils<AccDataType>::mant) * 0.5 * number_of_accumulations;
|
||||
}
|
||||
return std::max(acc_error, midway_error);
|
||||
}
|
||||
|
||||
@@ -88,19 +88,19 @@ function(add_instance_library INSTANCE_NAME)
|
||||
foreach(source IN LISTS ARGN)
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
if(source MATCHES "_xdl")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
elseif(source MATCHES "_wmma")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
elseif(source MATCHES "mha")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
#only build the fp8 gemm instances for gfx908/90a if the build argument is set
|
||||
if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
if(source MATCHES "gemm_xdl_universal" AND source MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
if(source MATCHES "gemm_multiply_multiply_f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
endif()
|
||||
set(offload_targets)
|
||||
|
||||
@@ -15,6 +15,10 @@ set(GROUPED_CONV2D_BWD_WEIGHT
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp
|
||||
)
|
||||
|
||||
if(DL_KERNELS)
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
NGKHW,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances<
|
||||
2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
NGKHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances<
|
||||
2,
|
||||
NGCHW,
|
||||
GKYXC,
|
||||
NGKHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_p
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances<
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
|
||||
@@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_p
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances<
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
|
||||
@@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
|
||||
@@ -15,6 +15,10 @@ set(GROUPED_CONV3D_BWD_WEIGHT
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp
|
||||
)
|
||||
|
||||
if(DL_KERNELS)
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf1
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances<
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
|
||||
@@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf1
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances<
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
|
||||
@@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances<
|
||||
device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
NGKDHW,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances<
|
||||
3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
NGKDHW,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances<
|
||||
3,
|
||||
NGCDHW,
|
||||
GKZYXC,
|
||||
NGKDHW,
|
||||
ConvBwdWeightDefault,
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion::v1>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -240,6 +240,19 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
|
||||
{
|
||||
out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data());
|
||||
|
||||
auto number_of_accumulations = 1;
|
||||
static_assert(
|
||||
ReduceOpId == ck::ReduceTensorOp::AVG || ReduceOpId == ck::ReduceTensorOp::MAX,
|
||||
"Warning: Unhandled ReduceOpId for setting up the number of accumulations!");
|
||||
|
||||
if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG)
|
||||
{
|
||||
for(size_t i = 0; i < kernel_params.window_spatial_lengths.size(); ++i)
|
||||
{
|
||||
number_of_accumulations *= kernel_params.window_spatial_lengths.at(i);
|
||||
}
|
||||
}
|
||||
|
||||
auto absolute_error_threshold = 1.0;
|
||||
switch(in_params.init_method)
|
||||
{
|
||||
@@ -250,9 +263,10 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
|
||||
|
||||
absolute_error_threshold =
|
||||
ck::utils::get_absolute_threshold<ComputeDataType, OutDataType>(
|
||||
absolute_error_threshold);
|
||||
absolute_error_threshold, number_of_accumulations);
|
||||
auto relative_error_threshold =
|
||||
ck::utils::get_relative_threshold<ComputeDataType, OutDataType>();
|
||||
ck::utils::get_relative_threshold<ComputeDataType, OutDataType>(
|
||||
number_of_accumulations);
|
||||
|
||||
bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData,
|
||||
out_n_c_do_ho_wo_host.mData,
|
||||
|
||||
@@ -101,7 +101,7 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
|
||||
@@ -164,7 +164,7 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
{
|
||||
return profile(F16{}, F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
|
||||
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F16{}, F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{});
|
||||
@@ -198,7 +198,7 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
{
|
||||
return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{});
|
||||
}
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
|
||||
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{});
|
||||
|
||||
@@ -85,7 +85,7 @@ int profile_layernorm(int argc, char* argv[])
|
||||
|
||||
if(data_type == ck::DataTypeEnum::Half)
|
||||
{
|
||||
ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, F32, false, rank>(
|
||||
ck::profiler::profile_layernorm_impl<F16, F16, F16, F32, F16, F16, false, rank>(
|
||||
do_verification, init_method, do_log, time_kernel, length);
|
||||
}
|
||||
else if(data_type == ck::DataTypeEnum::Float)
|
||||
|
||||
@@ -133,12 +133,12 @@ def parse_logfile(logfile):
|
||||
if 'Best Perf' in line:
|
||||
lst=line.split()
|
||||
res.append(lst[4])
|
||||
elif 'onnx_gemm' in logfile or 'mixed_gemm' in logfile:
|
||||
elif 'onnx_gemm' in logfile:
|
||||
for line in open(logfile):
|
||||
if 'Best Perf' in line:
|
||||
lst=line.split()
|
||||
res.append(lst[33])
|
||||
elif 'splitK_gemm' in logfile:
|
||||
elif 'splitK_gemm' in logfile or 'mixed_gemm' in logfile:
|
||||
for line in open(logfile):
|
||||
if 'Best Perf' in line:
|
||||
lst=line.split()
|
||||
|
||||
@@ -22,6 +22,7 @@ python3 process_perf_data.py perf_gemm_bilinear.log
|
||||
python3 process_perf_data.py perf_reduction.log
|
||||
python3 process_perf_data.py perf_splitK_gemm.log
|
||||
python3 process_perf_data.py perf_onnx_gemm.log
|
||||
python3 process_perf_data.py perf_mixed_gemm.log
|
||||
|
||||
file=./perf_fmha_fwd_gfx942.log
|
||||
if [ -e "$file" ]; then
|
||||
|
||||
@@ -64,11 +64,11 @@ function(add_test_executable TEST_NAME)
|
||||
#only continue if there are some source files left on the list
|
||||
if(ARGN)
|
||||
if(ARGN MATCHES "_xdl")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
elseif(ARGN MATCHES "_wmma")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
elseif(ARGN MATCHES "_smfmac")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201)
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${TEST_NAME} ${ARGN})
|
||||
@@ -141,11 +141,11 @@ function(add_gtest_executable TEST_NAME)
|
||||
#only continue if there are some source files left on the list
|
||||
if(ARGN)
|
||||
if(ARGN MATCHES "_xdl")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
elseif(ARGN MATCHES "_wmma")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
elseif(ARGN MATCHES "_smfmac")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201)
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${TEST_NAME} ${ARGN})
|
||||
|
||||
@@ -53,9 +53,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
constexpr bool kPadA = true;
|
||||
constexpr bool kPadB = true;
|
||||
constexpr bool kPadC = true;
|
||||
constexpr bool kPadM = true;
|
||||
constexpr bool kPadN = true;
|
||||
constexpr bool kPadK = true;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
@@ -68,9 +68,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
|
||||
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
|
||||
|
||||
using GemmEpilogue = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
|
||||
@@ -108,7 +108,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Lunching kernel with args:"
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
|
||||
@@ -56,7 +56,7 @@ class TestGemmUniversal_KM_NK
|
||||
using KernelTypes_MK_KN = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< F16, F16, F16, F16>,
|
||||
#if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
@@ -66,7 +66,7 @@ using KernelTypes_MK_KN = ::testing::Types<
|
||||
using KernelTypes_MK_NK = ::testing::Types<
|
||||
// ADataType, BDataType, ComputeDataType, CDataType
|
||||
std::tuple< F16, F16, F16, F16>,
|
||||
#if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94))
|
||||
std::tuple< F16, F8, F16, F16>,
|
||||
std::tuple< F8, F16, F16, F16>,
|
||||
std::tuple< F8, F8, F8, BF16>,
|
||||
|
||||
Reference in New Issue
Block a user