mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa
This commit is contained in:
@@ -1,16 +1,20 @@
|
|||||||
ARG BASE_DOCKER="rocm/pytorch:latest"
|
ARG BASE_DOCKER="rocm/pytorch:latest"
|
||||||
FROM $BASE_DOCKER
|
FROM $BASE_DOCKER
|
||||||
RUN groupadd -f render && \
|
ARG AITER_BRANCH="main"
|
||||||
|
ARG CK_AITER_BRANCH="develop"
|
||||||
|
RUN groupadd -g 109 render && \
|
||||||
|
usermod -u 1001 jenkins && \
|
||||||
|
groupmod -g 1001 jenkins && \
|
||||||
pip install pandas zmq einops && \
|
pip install pandas zmq einops && \
|
||||||
pip install numpy==1.26.2 && \
|
pip install numpy==1.26.2 && \
|
||||||
sudo mkdir /home/jenkins && \
|
sudo mkdir /home/jenkins && \
|
||||||
sudo mkdir /home/jenkins/workspace && \
|
sudo mkdir /home/jenkins/workspace && \
|
||||||
cd /home/jenkins/workspace && \
|
cd /home/jenkins/workspace && \
|
||||||
rm -rf aiter && \
|
rm -rf aiter && \
|
||||||
git clone --recursive https://github.com/ROCm/aiter.git && \
|
git clone -b "$AITER_BRANCH" --recursive https://github.com/ROCm/aiter.git && \
|
||||||
cd aiter && \
|
cd aiter && \
|
||||||
rm -rf 3rdparty/composable_kernel/ && \
|
rm -rf 3rdparty/composable_kernel/ && \
|
||||||
git clone https://github.com/ROCm/composable_kernel.git 3rdparty/composable_kernel/ && \
|
git clone -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git 3rdparty/composable_kernel/ && \
|
||||||
python3 setup.py develop && \
|
python3 setup.py develop && \
|
||||||
chown -R jenkins:jenkins /home/jenkins/workspace && \
|
chown -R jenkins:jenkins /home/jenkins/workspace && \
|
||||||
chmod -R a+rwx /home/jenkins/workspace && \
|
chmod -R a+rwx /home/jenkins/workspace && \
|
||||||
|
|||||||
164
Jenkinsfile
vendored
164
Jenkinsfile
vendored
@@ -190,7 +190,7 @@ def buildDocker(install_prefix){
|
|||||||
}
|
}
|
||||||
else if(params.RUN_AITER_TESTS){
|
else if(params.RUN_AITER_TESTS){
|
||||||
image_name = "rocm/composable_kernel:ck_aiter"
|
image_name = "rocm/composable_kernel:ck_aiter"
|
||||||
dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter . "
|
dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . "
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
dockerArgs = dockerArgs + " -f Dockerfile . "
|
dockerArgs = dockerArgs + " -f Dockerfile . "
|
||||||
@@ -438,34 +438,6 @@ def cmake_build(Map conf=[:]){
|
|||||||
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
|
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (params.RUN_CK_TILE_TRANSPOSE_TESTS){
|
|
||||||
try{
|
|
||||||
archiveArtifacts "perf_transpose_*.log"
|
|
||||||
if (arch_type == 1){
|
|
||||||
stash includes: "perf_transpose_**_gfx90a.log", name: "perf_transpose_log_gfx90a"
|
|
||||||
}
|
|
||||||
else if (arch_type == 2){
|
|
||||||
stash includes: "perf_transpose_**_gfx942.log", name: "perf_transpose_log_gfx942"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch(Exception err){
|
|
||||||
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (params.RUN_CK_TILE_GEMM_TESTS){
|
|
||||||
try{
|
|
||||||
archiveArtifacts "perf_tile_gemm_**.log"
|
|
||||||
if (arch == 1){
|
|
||||||
stash includes: "perf_tile_gemm_**_gfx90a.log", name: "perf_tile_gemm_log_gfx90a"
|
|
||||||
}
|
|
||||||
else if (arch == 2){
|
|
||||||
stash includes: "perf_tile_gemm_**_gfx942.log", name: "perf_tile_gemm_log_gfx942"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
catch(Exception err){
|
|
||||||
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def buildHipClangJob(Map conf=[:]){
|
def buildHipClangJob(Map conf=[:]){
|
||||||
@@ -762,24 +734,6 @@ def process_results(Map conf=[:]){
|
|||||||
echo "could not locate the FMHA performance logs: ${err.getMessage()}."
|
echo "could not locate the FMHA performance logs: ${err.getMessage()}."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (params.RUN_CK_TILE_TRANSPOSE_TESTS){
|
|
||||||
try{
|
|
||||||
unstash "perf_transpose_log_gfx942"
|
|
||||||
unstash "perf_transpose_log_gfx90a"
|
|
||||||
}
|
|
||||||
catch(Exception err){
|
|
||||||
echo "could not locate the Transpose performance logs: ${err.getMessage()}."
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (params.RUN_CK_TILE_GEMM_TESTS){
|
|
||||||
try{
|
|
||||||
unstash "perf_tile_gemm_log_gfx942"
|
|
||||||
unstash "perf_tile_gemm_log_gfx90a"
|
|
||||||
}
|
|
||||||
catch(Exception err){
|
|
||||||
echo "could not locate the GEMM performance logs: ${err.getMessage()}."
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){
|
if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){
|
||||||
// unstash deb packages
|
// unstash deb packages
|
||||||
unstash "packages"
|
unstash "packages"
|
||||||
@@ -843,10 +797,10 @@ def run_aiter_tests(Map conf=[:]){
|
|||||||
withDockerContainer(image: image, args: dockerOpts) {
|
withDockerContainer(image: image, args: dockerOpts) {
|
||||||
timeout(time: 45, unit: 'MINUTES'){
|
timeout(time: 45, unit: 'MINUTES'){
|
||||||
try{
|
try{
|
||||||
sh "python3 --version"
|
|
||||||
sh "rocminfo"
|
sh "rocminfo"
|
||||||
sh "python3 ../aiter/op_tests/test_gemm_a8w8_blockscale.py"
|
sh "python3 --version"
|
||||||
//sh "python3 ../aiter/op_tests/test_mha.py"
|
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py"
|
||||||
|
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py"
|
||||||
}
|
}
|
||||||
catch(e){
|
catch(e){
|
||||||
echo "Throwing error exception while running AITER tests"
|
echo "Throwing error exception while running AITER tests"
|
||||||
@@ -861,7 +815,7 @@ def run_aiter_tests(Map conf=[:]){
|
|||||||
}
|
}
|
||||||
|
|
||||||
//launch develop branch daily jobs
|
//launch develop branch daily jobs
|
||||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
||||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
||||||
@@ -941,14 +895,6 @@ pipeline {
|
|||||||
name: "RUN_CK_TILE_FMHA_TESTS",
|
name: "RUN_CK_TILE_FMHA_TESTS",
|
||||||
defaultValue: false,
|
defaultValue: false,
|
||||||
description: "Run the ck_tile FMHA tests (default: OFF)")
|
description: "Run the ck_tile FMHA tests (default: OFF)")
|
||||||
booleanParam(
|
|
||||||
name: "RUN_CK_TILE_TRANSPOSE_TESTS",
|
|
||||||
defaultValue: false,
|
|
||||||
description: "Run the ck_tile Transpose tests (default: OFF)")
|
|
||||||
booleanParam(
|
|
||||||
name: "RUN_CK_TILE_GEMM_TESTS",
|
|
||||||
defaultValue: false,
|
|
||||||
description: "Run the ck_tile GEMM tests (default: OFF)")
|
|
||||||
booleanParam(
|
booleanParam(
|
||||||
name: "RUN_TILE_ENGINE_GEMM_TESTS",
|
name: "RUN_TILE_ENGINE_GEMM_TESTS",
|
||||||
defaultValue: false,
|
defaultValue: false,
|
||||||
@@ -1009,6 +955,14 @@ pipeline {
|
|||||||
name: "RUN_AITER_TESTS",
|
name: "RUN_AITER_TESTS",
|
||||||
defaultValue: false,
|
defaultValue: false,
|
||||||
description: "Run AITER tests with latest CK develop branch (default: OFF)")
|
description: "Run AITER tests with latest CK develop branch (default: OFF)")
|
||||||
|
string(
|
||||||
|
name: 'aiter_branch',
|
||||||
|
defaultValue: 'main',
|
||||||
|
description: 'Specify which branch of AITER to use (default: main)')
|
||||||
|
string(
|
||||||
|
name: 'ck_aiter_branch',
|
||||||
|
defaultValue: 'develop',
|
||||||
|
description: 'Specify which branch of CK to test with AITER (default: develop)')
|
||||||
}
|
}
|
||||||
environment{
|
environment{
|
||||||
dbuser = "${dbuser}"
|
dbuser = "${dbuser}"
|
||||||
@@ -1093,13 +1047,13 @@ pipeline {
|
|||||||
{
|
{
|
||||||
parallel
|
parallel
|
||||||
{
|
{
|
||||||
stage("Run AITER Tests on gfx90a")
|
stage("Run AITER Tests on gfx942")
|
||||||
{
|
{
|
||||||
when {
|
when {
|
||||||
beforeAgent true
|
beforeAgent true
|
||||||
expression { params.RUN_AITER_TESTS.toBoolean() }
|
expression { params.RUN_AITER_TESTS.toBoolean() }
|
||||||
}
|
}
|
||||||
agent{ label rocmnode("gfx90a")}
|
agent{ label rocmnode("gfx942")}
|
||||||
steps{
|
steps{
|
||||||
run_aiter_tests()
|
run_aiter_tests()
|
||||||
cleanWs()
|
cleanWs()
|
||||||
@@ -1198,94 +1152,6 @@ pipeline {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stage("Run CK_TILE_TRANSPOSE Tests")
|
|
||||||
{
|
|
||||||
parallel
|
|
||||||
{
|
|
||||||
stage("Run CK_TILE_TRANSPOSE Tests on gfx90a")
|
|
||||||
{
|
|
||||||
when {
|
|
||||||
beforeAgent true
|
|
||||||
expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() }
|
|
||||||
}
|
|
||||||
agent{ label rocmnode("gfx90a") }
|
|
||||||
environment{
|
|
||||||
setup_args = "NO_CK_BUILD"
|
|
||||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
|
|
||||||
make -j64 tile_example_batched_transpose && \
|
|
||||||
cd ../ &&
|
|
||||||
example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """
|
|
||||||
}
|
|
||||||
steps{
|
|
||||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
|
||||||
cleanWs()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stage("Run CK_TILE_TRANSPOSE Tests on gfx942")
|
|
||||||
{
|
|
||||||
when {
|
|
||||||
beforeAgent true
|
|
||||||
expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() }
|
|
||||||
}
|
|
||||||
agent{ label rocmnode("gfx942") }
|
|
||||||
environment{
|
|
||||||
setup_args = "NO_CK_BUILD"
|
|
||||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \
|
|
||||||
make -j64 tile_example_batched_transpose && \
|
|
||||||
cd ../ &&
|
|
||||||
example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
|
|
||||||
}
|
|
||||||
steps{
|
|
||||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
|
||||||
cleanWs()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stage("Run CK_TILE_GEMM Tests")
|
|
||||||
{
|
|
||||||
parallel
|
|
||||||
{
|
|
||||||
stage("Run CK_TILE_GEMM Tests on gfx90a")
|
|
||||||
{
|
|
||||||
when {
|
|
||||||
beforeAgent true
|
|
||||||
expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() }
|
|
||||||
}
|
|
||||||
agent{ label rocmnode("gfx90a") }
|
|
||||||
environment{
|
|
||||||
setup_args = "NO_CK_BUILD"
|
|
||||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
|
|
||||||
make -j64 tile_example_gemm_universal && \
|
|
||||||
cd ../ &&
|
|
||||||
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """
|
|
||||||
}
|
|
||||||
steps{
|
|
||||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
|
||||||
cleanWs()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stage("Run CK_TILE_GEMM Tests on gfx942")
|
|
||||||
{
|
|
||||||
when {
|
|
||||||
beforeAgent true
|
|
||||||
expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() }
|
|
||||||
}
|
|
||||||
agent{ label rocmnode("gfx942") }
|
|
||||||
environment{
|
|
||||||
setup_args = "NO_CK_BUILD"
|
|
||||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \
|
|
||||||
make -j64 tile_example_gemm_universal && \
|
|
||||||
cd ../ &&
|
|
||||||
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
|
|
||||||
}
|
|
||||||
steps{
|
|
||||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
|
||||||
cleanWs()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
stage("Run TILE_ENGINE_GEMM Tests")
|
stage("Run TILE_ENGINE_GEMM Tests")
|
||||||
{
|
{
|
||||||
parallel
|
parallel
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
|
|||||||
4. Build the entire CK library:
|
4. Build the entire CK library:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
make -j
|
make -j"$(nproc)"
|
||||||
```
|
```
|
||||||
|
|
||||||
5. Install CK:
|
5. Install CK:
|
||||||
@@ -213,4 +213,4 @@ script/uninstall_precommit.sh
|
|||||||
```
|
```
|
||||||
|
|
||||||
If you need to temporarily disable pre-commit hooks, you can add the `--no-verify` option to the
|
If you need to temporarily disable pre-commit hooks, you can add the `--no-verify` option to the
|
||||||
`git commit` command.
|
`git commit` command.
|
||||||
|
|||||||
@@ -35,8 +35,6 @@ struct Add
|
|||||||
|
|
||||||
return type_convert<T>(y_ + x_);
|
return type_convert<T>(y_ + x_);
|
||||||
}
|
}
|
||||||
|
|
||||||
static constexpr bool requires_special_combine = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct SquareAdd
|
struct SquareAdd
|
||||||
@@ -64,28 +62,6 @@ struct SquareAdd
|
|||||||
float x_ = type_convert<float>(x);
|
float x_ = type_convert<float>(x);
|
||||||
return type_convert<T>(y_ + (x_ * x_));
|
return type_convert<T>(y_ + (x_ * x_));
|
||||||
}
|
}
|
||||||
|
|
||||||
// For combining partial results
|
|
||||||
template <typename T,
|
|
||||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
|
||||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
|
||||||
CK_TILE_HOST_DEVICE constexpr T combine_partial_results(const T& partial1,
|
|
||||||
const T& partial2) const
|
|
||||||
{
|
|
||||||
return partial1 + partial2; // Just add the partial sums, don't square again
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T,
|
|
||||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
|
||||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
|
||||||
CK_TILE_HOST_DEVICE constexpr T combine_partial_results(T& partial1, T& partial2) const
|
|
||||||
{
|
|
||||||
float partial1_ = type_convert<float>(partial1);
|
|
||||||
float partial2_ = type_convert<float>(partial2);
|
|
||||||
return type_convert<T>(partial1_ + partial2_);
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr bool requires_special_combine = true;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Max
|
struct Max
|
||||||
@@ -109,8 +85,6 @@ struct Max
|
|||||||
{
|
{
|
||||||
return max(y, x);
|
return max(y, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
static constexpr bool requires_special_combine = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct AbsMax
|
struct AbsMax
|
||||||
@@ -134,8 +108,6 @@ struct AbsMax
|
|||||||
{
|
{
|
||||||
return max(y, abs(x));
|
return max(y, abs(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
static constexpr bool requires_special_combine = false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ReduceOp
|
} // namespace ReduceOp
|
||||||
|
|||||||
@@ -183,16 +183,7 @@ struct BlockReduce2dSync
|
|||||||
|
|
||||||
// pull data from remote lane
|
// pull data from remote lane
|
||||||
const auto v_remote = warp_shuffle(v_local, src_lane);
|
const auto v_remote = warp_shuffle(v_local, src_lane);
|
||||||
|
v_local = reduce_func(v_local, v_remote);
|
||||||
// For reduce, use combine_partial_results for operations that require it
|
|
||||||
if constexpr(ReduceFunc::requires_special_combine)
|
|
||||||
{
|
|
||||||
v_local = reduce_func.combine_partial_results(v_local, v_remote);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
v_local = reduce_func(v_local, v_remote);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -309,16 +300,7 @@ struct BlockReduce2dCrossWarpSync
|
|||||||
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
|
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
|
||||||
constexpr auto i_1 = number<i_1_n1 + 1>{};
|
constexpr auto i_1 = number<i_1_n1 + 1>{};
|
||||||
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
|
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
|
||||||
|
v_local = reduce_func(v_local, v_remote);
|
||||||
// For reduce, use combine_partial_results for operations that require it
|
|
||||||
if constexpr(ReduceFunc::requires_special_combine)
|
|
||||||
{
|
|
||||||
v_local = reduce_func.combine_partial_results(v_local, v_remote);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
v_local = reduce_func(v_local, v_remote);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
y_tensor.get_thread_buffer()(i_0) = v_local;
|
y_tensor.get_thread_buffer()(i_0) = v_local;
|
||||||
|
|||||||
@@ -189,7 +189,9 @@ struct Reduce
|
|||||||
/// @note Requirements:
|
/// @note Requirements:
|
||||||
/// - y_continous_dim % ThreadTile_N == 0 (for proper thread distribution)
|
/// - y_continous_dim % ThreadTile_N == 0 (for proper thread distribution)
|
||||||
/// - input_strides[-1] == 1 (for contiguous memory access)
|
/// - input_strides[-1] == 1 (for contiguous memory access)
|
||||||
CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim, auto input_strides)
|
template <typename InputStrides>
|
||||||
|
CK_TILE_HOST static bool IsSupportedArgument(index_t y_continous_dim,
|
||||||
|
InputStrides input_strides)
|
||||||
{
|
{
|
||||||
using S = typename Problem::BlockShape;
|
using S = typename Problem::BlockShape;
|
||||||
|
|
||||||
|
|||||||
@@ -308,20 +308,8 @@ using TestConfig_F32_Max = std::tuple<float,
|
|||||||
Shape1_WarpTile,
|
Shape1_WarpTile,
|
||||||
Shape1_ThreadTile>;
|
Shape1_ThreadTile>;
|
||||||
|
|
||||||
using TestConfig_F32_SquareAdd = std::tuple<float,
|
using TestTypes = ::testing::
|
||||||
float,
|
Types<TestConfig_F32_Add, TestConfig_F16_Add, TestConfig_F32_CrossWarp, TestConfig_F32_Max>;
|
||||||
float,
|
|
||||||
ck_tile::ReduceOp::SquareAdd,
|
|
||||||
Shape1_BlockWarps,
|
|
||||||
Shape1_BlockTile,
|
|
||||||
Shape1_WarpTile,
|
|
||||||
Shape1_ThreadTile>;
|
|
||||||
|
|
||||||
using TestTypes = ::testing::Types<TestConfig_F32_Add,
|
|
||||||
TestConfig_F16_Add,
|
|
||||||
TestConfig_F32_CrossWarp,
|
|
||||||
TestConfig_F32_Max,
|
|
||||||
TestConfig_F32_SquareAdd>;
|
|
||||||
|
|
||||||
TYPED_TEST_SUITE(TestCkTileReduce, TestTypes);
|
TYPED_TEST_SUITE(TestCkTileReduce, TestTypes);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user