mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Merge commit '80ce6a573b4bb37c17c20eaac4fab48666be4edb' into develop
This commit is contained in:
@@ -234,6 +234,10 @@ endif()
|
||||
# new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA
|
||||
set(CK_TILE_USE_WMMA 0)
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx10")
|
||||
add_definitions(-DCK_GFX1030_SUPPORT)
|
||||
endif()
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
message(STATUS "Enabling WMMA instances")
|
||||
add_definitions(-DCK_USE_WMMA)
|
||||
|
||||
@@ -23,3 +23,4 @@ add_subdirectory(add_rmsnorm2d_rdquant)
|
||||
add_subdirectory(gemm_block_scale)
|
||||
add_subdirectory(utility)
|
||||
add_subdirectory(reduce)
|
||||
add_subdirectory(atomic_add_op)
|
||||
|
||||
2
test/ck_tile/atomic_add_op/CMakeLists.txt
Normal file
2
test/ck_tile/atomic_add_op/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# Builds the ck_tile atomic-add gtest binary.
add_gtest_executable(test_atomic test_atomic.cpp)
# NOTE: the former `set(CTEST_OUTPUT_ON_FAILURE ON)` was dropped. ctest only honours
# CTEST_OUTPUT_ON_FAILURE as an *environment* variable (or the --output-on-failure flag);
# a CMake variable of that name has no effect on test output.
|
||||
407
test/ck_tile/atomic_add_op/test_atomic.cpp
Executable file
407
test/ck_tile/atomic_add_op/test_atomic.cpp
Executable file
@@ -0,0 +1,407 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights
|
||||
// reserved.
|
||||
|
||||
#include <algorithm>
|
||||
#include <gtest/gtest.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "test_atomic.hpp"
|
||||
|
||||
struct AtomicKernelParam
|
||||
{
|
||||
AtomicKernelParam(ck_tile::index_t m_, ck_tile::index_t n_) : m(m_), n(n_) {}
|
||||
ck_tile::index_t m;
|
||||
ck_tile::index_t n;
|
||||
};
|
||||
|
||||
template <typename DataType_, ck_tile::index_t multiple_>
|
||||
class TestAtomicKernel : public ::testing::TestWithParam<std::tuple<int, int>>
|
||||
{
|
||||
struct AtomicKernelWaveSize64
|
||||
{
|
||||
using BlockWaves = ck_tile::sequence<2, 1>;
|
||||
using BlockTile = ck_tile::sequence<128, 8>;
|
||||
using WaveTile = ck_tile::sequence<64, 8>;
|
||||
static constexpr ck_tile::index_t kBlockSize = 128; // 2 waves * 64 lanes
|
||||
};
|
||||
|
||||
struct AtomicKernelWaveSize32
|
||||
{
|
||||
using BlockWaves = ck_tile::sequence<2, 1>;
|
||||
using BlockTile = ck_tile::sequence<64, 8>;
|
||||
using WaveTile = ck_tile::sequence<32, 8>; // 32*2 == 64
|
||||
static constexpr ck_tile::index_t kBlockSize = 64; // 2 waves * 32 lanes
|
||||
};
|
||||
|
||||
template <typename Config>
|
||||
void RunTestImpl_(const AtomicKernelParam& params, int require_warp_size, const char* tag)
|
||||
{
|
||||
// Device capability check & skip if wavesize mismatches
|
||||
int dev = 0;
|
||||
hipDeviceProp_t prop{};
|
||||
if(hipGetDevice(&dev) != hipSuccess || hipGetDeviceProperties(&prop, dev) != hipSuccess)
|
||||
{
|
||||
GTEST_SKIP() << "[" << tag << "] hipGetDeviceProperties failed; skipping.";
|
||||
}
|
||||
if(prop.warpSize != require_warp_size)
|
||||
{
|
||||
GTEST_SKIP() << "[" << tag << "] Device warpSize=" << prop.warpSize << " (requires "
|
||||
<< require_warp_size << "); skipping.";
|
||||
}
|
||||
|
||||
using XDataType = DataType_;
|
||||
|
||||
const ck_tile::index_t m = params.m;
|
||||
const ck_tile::index_t n = params.n;
|
||||
|
||||
std::cout << "[" << tag << "] Input Tensor Dimensions: " << m << ", " << n << std::endl;
|
||||
|
||||
constexpr int dword_bytes = 4;
|
||||
const int base_vec = dword_bytes / static_cast<int>(sizeof(XDataType));
|
||||
const int vec = multiple_ * base_vec;
|
||||
|
||||
ASSERT_EQ(n % vec, 0) << " Row dimension must be divisible by vector width: n=" << n
|
||||
<< " vec=" << vec << " (multiple=" << multiple_
|
||||
<< ", base_vec=" << base_vec << ")";
|
||||
|
||||
// host tensors
|
||||
ck_tile::HostTensor<XDataType> x_host_ref({m, n});
|
||||
ck_tile::HostTensor<XDataType> x_host_dev({m, n});
|
||||
|
||||
// device buffers
|
||||
ck_tile::DeviceMem x_dev_input(x_host_dev.get_element_space_size_in_bytes());
|
||||
x_dev_input.SetZero();
|
||||
x_host_ref.SetZero();
|
||||
|
||||
using BlockWaves = typename Config::BlockWaves;
|
||||
using BlockTile = typename Config::BlockTile;
|
||||
using WaveTile = typename Config::WaveTile;
|
||||
using Vector = ck_tile::sequence<1, vec>;
|
||||
|
||||
// Compile-time sanity: BlockTile == WaveTile * BlockWaves
|
||||
static_assert(BlockTile::at(ck_tile::number<0>{}) ==
|
||||
WaveTile::at(ck_tile::number<0>{}) * BlockWaves::at(ck_tile::number<0>{}),
|
||||
"BlockTile.M must equal WaveTile.M * BlockWaves.M");
|
||||
static_assert(BlockTile::at(ck_tile::number<1>{}) ==
|
||||
WaveTile::at(ck_tile::number<1>{}) * BlockWaves::at(ck_tile::number<1>{}),
|
||||
"BlockTile.N must equal WaveTile.N * BlockWaves.N");
|
||||
|
||||
std::cout << "[" << tag << "] Vector per thread = " << vec
|
||||
<< " BlockWaves=" << BlockWaves::at(ck_tile::number<0>{}) << "x"
|
||||
<< BlockWaves::at(ck_tile::number<1>{})
|
||||
<< " WaveTile=" << WaveTile::at(ck_tile::number<0>{}) << "x"
|
||||
<< WaveTile::at(ck_tile::number<1>{})
|
||||
<< " BlockTile=" << BlockTile::at(ck_tile::number<0>{}) << "x"
|
||||
<< BlockTile::at(ck_tile::number<1>{}) << std::endl;
|
||||
|
||||
const ck_tile::index_t kGridSize =
|
||||
ck_tile::integer_divide_ceil(m, BlockTile::at(ck_tile::number<0>{}));
|
||||
|
||||
using Shape = ck_tile::AtomicKernelShape<BlockWaves, BlockTile, WaveTile, Vector>;
|
||||
using Problem = ck_tile::AtomicKernelProblem<XDataType, Shape>;
|
||||
using Kernel = ck_tile::AtomicKernel<Problem>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = Config::kBlockSize;
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
(void)hipGetLastError(); // clear sticky
|
||||
|
||||
launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_dev_input.GetDeviceBuffer()),
|
||||
m,
|
||||
n));
|
||||
|
||||
ASSERT_EQ(hipPeekAtLastError(), hipSuccess)
|
||||
<< "[" << tag << "] hipPeekAtLastError: " << hipGetErrorString(hipGetLastError());
|
||||
ASSERT_EQ(hipDeviceSynchronize(), hipSuccess)
|
||||
<< "[" << tag << "] hipDeviceSynchronize failed";
|
||||
|
||||
// host reference computation
|
||||
x_dev_input.FromDevice(x_host_dev.mData.data());
|
||||
for(int i = 0; i < m; ++i)
|
||||
for(int j = 0; j < n; ++j)
|
||||
x_host_ref(i, j) = static_cast<XDataType>(1);
|
||||
|
||||
const bool pass = ck_tile::check_err(x_host_dev, x_host_ref);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
protected:
|
||||
// WaveSize = 64 path
|
||||
void RunTest(const AtomicKernelParam& params)
|
||||
{
|
||||
RunTestImpl_<AtomicKernelWaveSize64>(params, /*require_warp_size=*/64, "WS64");
|
||||
}
|
||||
|
||||
// WaveSize = 32 path
|
||||
void RunTestWave32(const AtomicKernelParam& params)
|
||||
{
|
||||
RunTestImpl_<AtomicKernelWaveSize32>(params, /*require_warp_size=*/32, "WS32");
|
||||
}
|
||||
};
|
||||
|
||||
// Concrete fixtures, one per (element type, dword multiple) combination.
// Each gets its own type so that TEST_P / INSTANTIATE_TEST_SUITE_P can address it by name.

// half_t
struct TestAtomicKernelHalf_1 : TestAtomicKernel<ck_tile::half_t, 1>
{
};
struct TestAtomicKernelHalf_2 : TestAtomicKernel<ck_tile::half_t, 2>
{
};
struct TestAtomicKernelHalf_4 : TestAtomicKernel<ck_tile::half_t, 4>
{
};

// bf16_t
struct TestAtomicKernelBF16_1 : TestAtomicKernel<ck_tile::bf16_t, 1>
{
};
struct TestAtomicKernelBF16_2 : TestAtomicKernel<ck_tile::bf16_t, 2>
{
};
struct TestAtomicKernelBF16_4 : TestAtomicKernel<ck_tile::bf16_t, 4>
{
};

// bf8_t
struct TestAtomicKernelBF8_1 : TestAtomicKernel<ck_tile::bf8_t, 1>
{
};
struct TestAtomicKernelBF8_2 : TestAtomicKernel<ck_tile::bf8_t, 2>
{
};

// fp8_t
struct TestAtomicKernelFP8_1 : TestAtomicKernel<ck_tile::fp8_t, 1>
{
};
struct TestAtomicKernelFP8_2 : TestAtomicKernel<ck_tile::fp8_t, 2>
{
};

// float
struct TestAtomicKernelFloat_1 : TestAtomicKernel<float, 1>
{
};
struct TestAtomicKernelFloat_2 : TestAtomicKernel<float, 2>
{
};
struct TestAtomicKernelFloat_4 : TestAtomicKernel<float, 4>
{
};
|
||||
|
||||
//
|
||||
// WaveSize=64 tests (auto-skip on wave32 devices)
|
||||
//
|
||||
#if defined(CK_USE_XDL)
|
||||
TEST_P(TestAtomicKernelHalf_1, TestCorrectness)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTest({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelHalf_2, TestCorrectness)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTest({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelHalf_4, TestCorrectness)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTest({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF16_1, TestCorrectness)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTest({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF16_2, TestCorrectness)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTest({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF16_4, TestCorrectness)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTest({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF8_1, TestCorrectness)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTest({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF8_2, TestCorrectness)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTest({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelFP8_1, TestCorrectness)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTest({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelFP8_2, TestCorrectness)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTest({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelFloat_1, TestCorrectness)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTest({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelFloat_2, TestCorrectness)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTest({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelFloat_4, TestCorrectness)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTest({M, N});
|
||||
}
|
||||
|
||||
//
|
||||
// WaveSize=32 tests (auto-skip on wave64 devices)
|
||||
//
|
||||
#else
|
||||
TEST_P(TestAtomicKernelHalf_1, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelHalf_2, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelHalf_4, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF16_1, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF16_2, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF16_4, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF8_1, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelBF8_2, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelFP8_1, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelFP8_2, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelFloat_1, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
TEST_P(TestAtomicKernelFloat_2, TestCorrectnessWS32)
|
||||
{
|
||||
auto [M, N] = GetParam();
|
||||
this->RunTestWave32({M, N});
|
||||
}
|
||||
#endif
|
||||
|
||||
// Common parameter lists
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelHalf_1,
|
||||
::testing::Values(std::tuple{64, 8},
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelHalf_2,
|
||||
::testing::Values(std::tuple{64, 8},
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelHalf_4,
|
||||
::testing::Values(std::tuple{64, 8},
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelBF16_1,
|
||||
::testing::Values(std::tuple{64, 8},
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelBF16_2,
|
||||
::testing::Values(std::tuple{64, 8},
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelBF16_4,
|
||||
::testing::Values(std::tuple{64, 8},
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelBF8_1,
|
||||
::testing::Values(std::tuple{64, 8},
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelBF8_2,
|
||||
::testing::Values(std::tuple{64, 8},
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelFP8_1,
|
||||
::testing::Values(std::tuple{64, 8},
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelFP8_2,
|
||||
::testing::Values(std::tuple{64, 8},
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelFloat_1,
|
||||
::testing::Values(std::tuple{64, 8},
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelFloat_2,
|
||||
::testing::Values(std::tuple{64, 8},
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
INSTANTIATE_TEST_SUITE_P(TestAtomicKernelSuite,
|
||||
TestAtomicKernelFloat_4,
|
||||
::testing::Values(std::tuple{64, 8},
|
||||
std::tuple{64, 16},
|
||||
std::tuple{64, 32}));
|
||||
#endif
|
||||
115
test/ck_tile/atomic_add_op/test_atomic.hpp
Normal file
115
test/ck_tile/atomic_add_op/test_atomic.hpp
Normal file
@@ -0,0 +1,115 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockWaves, typename BlockTile, typename WaveTile, typename Vector>
|
||||
struct AtomicKernelShape
|
||||
{
|
||||
static constexpr index_t MWarps = BlockWaves::at(number<0>{});
|
||||
static constexpr index_t NWarps = BlockWaves::at(number<1>{});
|
||||
|
||||
static constexpr index_t Block_M = BlockTile::at(number<0>{});
|
||||
static constexpr index_t Block_N = BlockTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t Warp_M = WaveTile::at(number<0>{});
|
||||
static constexpr index_t Warp_N = WaveTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t Vector_M = Vector::at(number<0>{});
|
||||
static constexpr index_t Vector_N = Vector::at(number<1>{});
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
|
||||
static constexpr index_t WarpPerBlock_M = MWarps;
|
||||
static constexpr index_t WarpPerBlock_N = NWarps;
|
||||
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
static constexpr index_t WaveNum = reduce_on_sequence(BlockWaves{}, multiplies{}, number<1>{});
|
||||
|
||||
static constexpr index_t BlockSize = get_warp_size() * WaveNum;
|
||||
};
|
||||
|
||||
template <typename XDataType_, typename BlockShape_>
|
||||
struct AtomicKernelProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
};
|
||||
|
||||
template <typename Problem_>
|
||||
struct AtomicKernel
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using XDataType = typename Problem::XDataType;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
|
||||
constexpr index_t X0 = S::ThreadPerWarp_N;
|
||||
constexpr index_t X1 = S::Vector_N;
|
||||
|
||||
constexpr index_t Y0 = S::WaveNum;
|
||||
constexpr index_t Y2 = warp_size / X0;
|
||||
constexpr index_t Y1 = S::Warp_M / Y2;
|
||||
|
||||
constexpr auto encoding =
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{};
|
||||
|
||||
return make_static_tile_distribution(encoding);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(XDataType* input, index_t M, index_t N) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
constexpr auto block_dims = make_tuple(number<S::Block_M>{}, number<S::Block_N>{});
|
||||
|
||||
const index_t iM = __builtin_amdgcn_readfirstlane(get_block_id() * S::Block_M);
|
||||
|
||||
const auto input_view =
|
||||
make_naive_tensor_view<address_space_enum::global, memory_operation_enum::atomic_add>(
|
||||
input, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
|
||||
auto input_window = make_tile_window(input_view, block_dims, {iM, 0});
|
||||
|
||||
const index_t num_iterations =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N));
|
||||
using tmp_tile =
|
||||
decltype(make_static_distributed_tensor<XDataType>(MakeTileDistribution<Problem>()));
|
||||
|
||||
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_iterations; iN++)
|
||||
{
|
||||
tmp_tile add_value_tile;
|
||||
tile_elementwise_inout([](auto& c) { c = static_cast<XDataType>(1.0f); },
|
||||
add_value_tile);
|
||||
|
||||
update_tile(input_window, add_value_tile);
|
||||
__syncthreads();
|
||||
|
||||
move_tile_window(input_window, {0, S::Block_N});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user