topk_softmax (#1592)

* topk_softmax

* remove some file

* fix atomix linear_offset

* address various comment, and change sfc get_index api to static(tuple)

[ROCm/composable_kernel commit: b098b71b05]
This commit is contained in:
carlushuang
2024-10-26 23:52:49 +08:00
committed by GitHub
parent 930195c384
commit a6b29fcb01
41 changed files with 5603 additions and 226 deletions

View File

@@ -0,0 +1,8 @@
add_executable(tile_example_topk_softmax EXCLUDE_FROM_ALL topk_softmax.cpp topk_softmax_api.cpp)
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(tile_example_topk_softmax PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})

View File

@@ -0,0 +1,28 @@
# topk-softmax
This folder contains example for topk-softmax kernel using ck_tile tile-programming implementation. This kernel is often used in Moe model, before launching the fused-moe-gemm block. The input is a `token*expert` 2d matrix. The op will do a softmax per row(`expert`), then find the `topk` value for each row. Output is a `token*topk` weight(usually fp32) and index(int32) 2d tensor.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_topk_softmax -j
```
This will result in an executable `build/bin/tile_example_topk_softmax`
## example
```
args:
-v weather do CPU validation or not (default:1)
-pr_i input data type. fp16/fp32 (representing 8/16/32 bit data) (default:fp16)
-pr_w output weight data type(currently only fp32 supported now) (default:fp32)
-t number of input tokens (default:32)
-e number of experts (default:8)
-k topk (default:2)
-st_i row stride of input, -1 means same as experts (default:-1)
-st_o row stride of output/indices, -1 means same as topk (default:-1)
-seed seed to be used, -1 means random every time (default:-1)
-kname when set to 1 it will print kernel name (default:0)
```

View File

@@ -0,0 +1,22 @@
#!/bin/sh
EXE=./build/bin/tile_example_topk_softmax
for pr_i in "fp16" "bf16" ; do
$EXE -pr_i=$pr_i -t=80 -e=17
$EXE -pr_i=$pr_i -t=111 -e=117
$EXE -pr_i=$pr_i -t=1000 -e=55
$EXE -pr_i=$pr_i -t=99 -e=180
$EXE -pr_i=$pr_i -t=175 -e=64 -k=8
$EXE -pr_i=$pr_i -t=65 -e=8 -k=2
$EXE -pr_i=$pr_i -t=1 -e=25
$EXE -pr_i=$pr_i -t=31 -e=19 -k=15
$EXE -pr_i=$pr_i -t=81 -e=37 -k=7
$EXE -pr_i=$pr_i -t=199 -e=128 -k=13
$EXE -pr_i=$pr_i -t=23 -e=1 -k=1
$EXE -pr_i=$pr_i -t=127 -e=99 -k=19 -st_i=233 -st_o=31
$EXE -pr_i=$pr_i -t=71 -e=11 -k=11 -st_i=30 -st_o=12
$EXE -pr_i=$pr_i -t=1 -e=1 -k=1
$EXE -pr_i=$pr_i -t=99 -e=2 -k=1 -st_i=11 -st_o=5
$EXE -pr_i=$pr_i -t=333 -e=99 -k=13 -st_i=191 -st_o=17
done

View File

@@ -0,0 +1,299 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "topk_softmax_api.hpp"
#if 0
template <typename T>
void dump_host_tensor_2d(const ck_tile::HostTensor<T>& x)
{
auto len = x.get_lengths();
assert(len.size() == 2);
std::cout << "[";
for(size_t i = 0; i < len[0]; i++)
{
std::cout << i << ": [";
for(size_t j = 0; j < len[1]; j++)
{
if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
{
auto v = ck_tile::type_convert<float>(x(i, j));
std::cout << v;
if(j != len[1] - 1)
std::cout << ",";
}
else
{
std::cout << x(i, j) << " ";
}
}
std::cout << "]";
if(i != len[0] - 1)
std::cout << ",";
else
std::cout << "]";
std::cout << std::endl;
}
std::cout << "--------------------" << std::endl;
}
#endif
// CPU reference
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
ck_tile::index_t k,
ck_tile::index_t dim = -1,
bool largest = true,
bool sorted = true)
{
using namespace ck_tile;
auto y = reference_softmax<InputType, WeightType, WeightType>(x, dim);
auto [y_values, y_indices] = reference_topk(y, k, dim, largest, sorted);
return ck_tile::make_tuple(y_values, y_indices);
}
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
ck_tile::HostTensor<WeightType>& y_values,
ck_tile::HostTensor<IndexType>& y_indices,
ck_tile::index_t k,
ck_tile::index_t dim = -1,
bool largest = true,
bool sorted = true)
{
using namespace ck_tile;
auto y = reference_softmax<InputType, WeightType, WeightType>(x, dim);
reference_topk(y, y_values, y_indices, k, dim, largest, sorted);
}
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
double rtol = 1e-3;
double atol = 1e-3;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
if(init_method == "ui" || init_method == "ni")
{
unsigned max_rounding_point_distance = 0;
double atol = 2e-3;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
else
{
unsigned max_rounding_point_distance = 1;
double atol = 0.0625;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "weather do CPU validation or not")
.insert("pr_i", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
.insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
.insert("t", "32", "number of input tokens")
.insert("e", "8", "number of experts")
.insert("k", "2", "topk")
.insert("st_i", "-1", "row stride of input, -1 means same as experts")
.insert("st_o", "-1", "row stride of output/indices, -1 means same as topk")
.insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "when set to 1 it will print kernel name")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
bool test_topk_softmax(ck_tile::ArgParser args)
{
int validate = args.get_int("v");
std::string input_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
int tokens = args.get_int("t");
int experts = args.get_int("e");
int topk = args.get_int("k");
int seed = args.get_int("seed");
int stride_input = args.get_int("st_i");
int stride_output = args.get_int("st_o");
int kname = args.get_int("kname");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
if(stride_input < 0)
{
stride_input = experts;
}
if(stride_output < 0)
{
stride_output = topk;
}
assert(stride_input >= experts);
assert(stride_output >= topk);
if(seed < 0)
{
seed = std::time(nullptr);
}
if(topk > experts)
{
printf("topk:%d value should be smaller than, or equal to number of experts:%d\n",
topk,
experts);
return false;
}
// tokens already considered batch size
ck_tile::HostTensor<InputType> x_host({tokens, experts}, {stride_input, 1});
ck_tile::HostTensor<WeightType> value_host({tokens, topk}, {stride_output, 1});
ck_tile::HostTensor<IndexType> index_host({tokens, topk}, {stride_output, 1});
{
// random require per-row unique
auto rand_gen = ck_tile::FillUniformDistribution_Unique<InputType>{
-5.f, 5.f, static_cast<uint32_t>(seed)};
for(int i_t = 0; i_t < tokens; i_t++)
{
ck_tile::HostTensor<InputType> x_row({experts});
rand_gen(x_row);
std::copy(x_row.begin(), x_row.end(), x_host.begin() + i_t * stride_input);
rand_gen.clear();
}
}
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem value_dev(value_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem index_dev(index_host.get_element_space_size_in_bytes());
x_dev.ToDevice(x_host.data());
topk_softmax_trait trait{input_prec, weight_prec, experts};
topk_softmax_kargs karg{x_dev.GetDeviceBuffer(),
value_dev.GetDeviceBuffer(),
index_dev.GetDeviceBuffer(),
tokens,
experts,
topk,
stride_input,
stride_output};
ck_tile::stream_config sc{nullptr,
true,
/* log_level = */ (kname ? 1 : 0),
warmup,
repeat};
auto ms = topk_softmax(trait, karg, sc);
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ",
input_prec.c_str(),
weight_prec.c_str(),
tokens,
experts,
topk,
stride_input,
stride_output,
ms);
if(ms < 0)
printf("not supported\n");
fflush(stdout);
if(ms < 0)
{
return false;
}
value_dev.FromDevice(value_host.data());
index_dev.FromDevice(index_host.data());
bool rtn = true;
if(validate)
{
ck_tile::HostTensor<WeightType> value_ref({tokens, topk}, {stride_output, 1});
ck_tile::HostTensor<IndexType> index_ref({tokens, topk}, {stride_output, 1});
reference_topk_softmax<InputType, WeightType, IndexType>(
x_host, value_ref, index_ref, topk);
auto [rtol, atol] = get_elimit<InputType>("");
for(int i_t = 0; i_t < tokens; i_t++)
{
auto s_begin = std::vector<size_t>{static_cast<size_t>(i_t), static_cast<size_t>(0)};
auto s_end =
std::vector<size_t>{static_cast<size_t>(i_t + 1), static_cast<size_t>(topk)};
auto s_value_host = value_host.slice(s_begin, s_end);
auto s_value_ref = value_ref.slice(s_begin, s_end);
rtn &= ck_tile::check_err(s_value_host,
s_value_ref,
std::string("[") + std::to_string(i_t) +
std::string("] Value Error:"),
rtol,
atol);
auto s_index_host = index_host.slice(s_begin, s_end);
auto s_index_ref = index_ref.slice(s_begin, s_end);
rtn &= ck_tile::check_err(s_index_host,
s_index_ref,
std::string("[") + std::to_string(i_t) +
std::string("] Index Error:"),
rtol,
atol);
}
}
printf("valid:%s\n", rtn ? "y" : "n");
fflush(stdout);
return rtn;
}
int main(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string input_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
bool r = true;
if(input_prec.compare("fp16") == 0 && weight_prec.compare("fp32") == 0)
{
r &= test_topk_softmax<ck_tile::fp16_t, float, ck_tile::index_t>(args);
}
else if(input_prec.compare("bf16") == 0 && weight_prec.compare("fp32") == 0)
{
r &= test_topk_softmax<ck_tile::bf16_t, float, ck_tile::index_t>(args);
}
return r ? 0 : -1;
}

View File

@@ -0,0 +1,96 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "topk_softmax_api.hpp"
#define TOPK_SOFTMAX_DISPATCH(experts_) \
constexpr ck_tile::index_t ts_experts = experts_; \
using ts_problem = ck_tile:: \
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>; \
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
\
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
\
auto kargs = kernel::MakeKargs(a); \
\
const dim3 grids = kernel::GridSize(a); \
constexpr dim3 blocks = kernel::BlockSize(); \
\
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs)); \
\
return ave_time;
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
{
if(t.input_type == "fp16" && t.weight_type == "fp32")
{
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
#if 1
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192)
}
#else
if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
}
#endif
}
else if(t.input_type == "bf16" && t.weight_type == "fp32")
{
#if 1
using ts_input_type = ck_tile::bf16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192)
}
#endif
}
return -1;
}

View File

@@ -0,0 +1,21 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/topk_softmax.hpp"
#include <string>
struct topk_softmax_trait
{
std::string input_type;
std::string weight_type; // currently always float
int experts;
};
struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs
{
};
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s);

View File

@@ -7,3 +7,5 @@ add_subdirectory(02_layernorm2d)
add_subdirectory(03_gemm)
add_subdirectory(04_img2col)
add_subdirectory(05_reduce)
add_subdirectory(09_topk_softmax)

View File

@@ -49,6 +49,7 @@
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -81,8 +81,10 @@ struct space_filling_curve
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
}
// Do not use this function directly!
// TODO: can refactor into generic lambda in the future
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr Index get_index(number<AccessIdx1d>)
static CK_TILE_HOST_DEVICE constexpr Index _get_index(number<AccessIdx1d>)
{
#if 0
/*
@@ -153,11 +155,11 @@ struct space_filling_curve
return idx_md;
}
// FIXME: rename this function
// FIXME: return tuple of number<>, which is compile time only variable
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
static CK_TILE_HOST_DEVICE constexpr auto get_index(number<AccessIdx1d>)
{
constexpr auto idx = get_index(number<AccessIdx1d>{});
constexpr auto idx = _get_index(number<AccessIdx1d>{});
return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
}

View File

@@ -621,6 +621,99 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
namespace impl {
// below type indicate the data type used for buffer load inline asm
// clang-format off
template<index_t N, typename T> struct smem_load_trait;
template<typename T> struct smem_load_trait<16, T> { using payload_t = fp32x4_t; };
template<typename T> struct smem_load_trait<8 , T> { using payload_t = fp32x2_t; };
template<typename T> struct smem_load_trait<4 , T> { using payload_t = float; };
template<typename T> struct smem_load_trait<2 , T> { using payload_t = float; };
template<typename T> struct smem_load_trait<1 , T> { using payload_t = float; };
// clang-format on
} // namespace impl
// NOTE: smem load/store no need pre_nop to make sure dependency by sw, happy :)
template <index_t>
struct smem_load;
template <>
struct smem_load<16>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 16);
using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t;
asm volatile("ds_read_b128 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct smem_load<8>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 8);
using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t;
asm volatile("ds_read_b64 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct smem_load<4>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t;
asm volatile("ds_read_b32 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct smem_load<2>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
asm volatile("ds_read_u16 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
template <>
struct smem_load<1>
{
template <typename T>
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
{
static_assert(sizeof(T) == 4);
using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
asm volatile("ds_read_u8 %0, %1 offset:%2"
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
: "v"(v_offset), "n"(i_offset)
: "memory");
}
};
// clang-format off
namespace impl{
@@ -976,6 +1069,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
int32x4_t rsrc,
@@ -1313,6 +1416,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_linear_addr_offset,
index_t flag = 0,
bool_constant<pre_nop> = {})
{
@@ -1327,7 +1431,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
0,
src_linear_addr_offset,
flag,
bool_constant<pre_nop>{});
}
@@ -1337,7 +1441,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
0,
src_linear_addr_offset,
flag,
bool_constant<pre_nop>{});
}
@@ -1365,6 +1469,43 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
bool_constant<pre_nop>{});
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
if constexpr(oob_conditional_check)
{
index_t v_offset = flag ? v_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
v_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
else
{
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
smem,
sizeof(uint32_t),
src_thread_addr_offset,
src_wave_addr_offset,
src_immediate_addr_offset,
static_cast<index_t>(coherence));
}
}
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
@@ -1685,6 +1826,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset,
index_t dst_linear_addr_offset,
index_t is_valid_element = 1)
{
constexpr index_t bytes = sizeof(T) * N;
@@ -1698,7 +1840,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0,
dst_linear_addr_offset,
is_valid_element);
}
else
@@ -1707,7 +1849,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
dst_linear_addr_offset);
}
}
@@ -2014,6 +2156,7 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
index_t src_element_space_size,
index_t is_valid_element = 0,
bool_constant<pre_nop> = {})
@@ -2022,12 +2165,14 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
is_valid_element,
bool_constant<pre_nop>{});
}
@@ -2041,16 +2186,19 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
index_t is_valid_element = 0,
bool_constant<pre_nop> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
is_valid_element,
bool_constant<pre_nop>{});
}
@@ -2066,6 +2214,7 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
const T* p_src_wave,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
index_t src_element_space_size,
bool_constant<pre_nop> = {})
{
@@ -2073,9 +2222,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>(
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
amd_async_buffer_load_impl<T, N, coherence>(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
bool_constant<pre_nop>{});
}
// This version support buffer resource as input arg
@@ -2086,12 +2240,42 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
bool_constant<pre_nop> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>(
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
amd_async_buffer_load_impl<T, N, coherence>(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
bool_constant<pre_nop>{});
}
// This version support buffer resource as input arg
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_async_buffer_load<T, N, coherence>(smem,
src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
is_valid_element,
bool_constant<oob_conditional_check>{});
}
// buffer_store requires:
@@ -2146,6 +2330,7 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const index_t dst_linear_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
@@ -2153,11 +2338,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
dst_thread_element_valid);
}
@@ -2221,16 +2408,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif
}
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <typename T, index_t NumElemsPerThread>
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset,

View File

@@ -41,6 +41,19 @@
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
// implementing the "memory address space" attribute
// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
#ifdef __HIPCC_
#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
#else
#define CK_TILE_GENERIC_ADDR
#define CK_TILE_GLOBAL_ADDR
#define CK_TILE_LDS_ADDR
#define CK_TILE_BUF_RES_ADDR
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#endif
@@ -205,3 +218,8 @@
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#endif
// workaround: compiler not emiting reciprocal instruction frm __frcp_rn()
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
#define CK_TILE_WORKAROUND_SWDEV_383542 1
#endif

View File

@@ -623,7 +623,7 @@ template <typename... Ys,
false>
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
return y;
@@ -635,7 +635,7 @@ template <typename... Ys,
false>
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
return y;
@@ -647,7 +647,7 @@ template <typename... Xs,
false>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
@@ -655,13 +655,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
return r;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] + y[i]; }, number<NSize>{});
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
@@ -669,13 +677,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
return r;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] - y[i]; }, number<NSize>{});
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
@@ -706,6 +722,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
return a * x;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] * y[i]; }, number<NSize>{});
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{

View File

@@ -487,55 +487,12 @@ struct log2e<float>
template <typename T = double>
constexpr T log2e_v = log2e<T>::value;
// math
CK_TILE_HOST_DEVICE
float abs(const float& x)
{
union
{
float f32;
uint32_t u32;
} y;
y.f32 = x;
y.u32 = y.u32 & 0x7fffffff;
return y.f32;
}
CK_TILE_HOST_DEVICE
bool isnan(const float& x)
{
uint32_t xx = bit_cast<uint32_t>(x);
return (xx & 0x7fffffff) > 0x7F800000;
}
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
CK_TILE_DEVICE
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
CK_TILE_DEVICE
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
CK_TILE_DEVICE
float exp(float x) { return __ocml_exp_f32(x); };
CK_TILE_HOST
float exp(float x) { return std::expf(x); }
CK_TILE_DEVICE
float exp2(float x) { return exp2f(x); };
CK_TILE_HOST
float exp2(float x) { return std::exp2f(x); };
CK_TILE_DEVICE
float log(float x) { return __logf(x); };
CK_TILE_HOST
float log(float x) { return std::logf(x); };
CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc)
{
return __builtin_amdgcn_sad_u16(x, y, acc);
@@ -554,4 +511,933 @@ CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
return (x > y ? (x - y) : (y - x)) + acc;
}
///////////////////////////////////////////////////////////////
} // namespace ck_tile
// blow function need data type pre-defined
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#ifndef __HIP_DEVICE_COMPILE__
#include <cmath>
#endif
namespace ck_tile {
#if CK_TILE_WORKAROUND_SWDEV_383542
extern "C" CK_TILE_DEVICE float __ocml_native_recip_f32(float);
#endif
// math functions for the host, some are implemented by calling C++ std functions
CK_TILE_HOST float abs(float x) { return std::abs(x); };
CK_TILE_HOST double abs(double x) { return std::abs(x); };
CK_TILE_HOST int8_t abs(int8_t x)
{
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn;
};
CK_TILE_HOST int32_t abs(int32_t x)
{
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn;
};
CK_TILE_HOST fp16_t abs(fp16_t x)
{
uint16_t xx = bit_cast<uint16_t>(x);
uint16_t abs_xx = xx & 0x7fff;
fp16_t abs_x = bit_cast<fp16_t>(abs_xx);
return abs_x;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST int4_t abs(int4_t x)
{
int4_t sgn = x >> (4 - 1);
return (x ^ sgn) - sgn;
}
#endif
CK_TILE_HOST bool isnan(float x) { return std::isnan(x); };
CK_TILE_HOST bool isnan(double x) { return std::isnan(x); };
CK_TILE_HOST bool isnan(int8_t x)
{
(void)x;
return false;
};
CK_TILE_HOST bool isnan(int32_t x)
{
(void)x;
return false;
};
CK_TILE_HOST bool isnan(fp16_t x)
{
uint16_t xx = bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST bool isnan(int4_t x)
{
(void)x;
return false;
};
#endif
CK_TILE_HOST fp16_t sqrt(fp16_t x)
{
return static_cast<fp16_t>(std::sqrt(static_cast<float>(x)));
};
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
template <typename T>
CK_TILE_HOST T tanh(T x)
{
return type_convert<T>(std::tanhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float tanh<float>(float x)
{
return std::tanhf(x);
};
template <>
CK_TILE_HOST double tanh<double>(double x)
{
return std::tanh(x);
};
template <typename T>
CK_TILE_HOST T acos(T x)
{
return type_convert<T>(std::acosf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float acos<float>(float x)
{
return std::acosf(x);
};
template <>
CK_TILE_HOST double acos<double>(double x)
{
return std::acos(x);
};
template <typename T>
CK_TILE_HOST T neg(T x)
{
return type_convert<T>(-(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float neg<float>(float x)
{
return -x;
};
template <>
CK_TILE_HOST double neg<double>(double x)
{
return -x;
};
template <>
CK_TILE_HOST int32_t neg<int32_t>(int32_t x)
{
return -x;
};
template <>
CK_TILE_HOST int8_t neg<int8_t>(int8_t x)
{
return -x;
};
template <typename T>
CK_TILE_HOST T atan(T x)
{
return type_convert<T>(std::atanf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float atan<float>(float x)
{
return std::atanf(x);
};
template <>
CK_TILE_HOST double atan<double>(double x)
{
return std::atan(x);
};
template <typename T>
CK_TILE_HOST T sin(T x)
{
return type_convert<T>(std::sinf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float sin<float>(float x)
{
return std::sinf(x);
};
template <>
CK_TILE_HOST double sin<double>(double x)
{
return std::sin(x);
};
template <typename T>
CK_TILE_HOST T asin(T x)
{
return type_convert<T>(std::asinf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float asin<float>(float x)
{
return std::asinf(x);
};
template <>
CK_TILE_HOST double asin<double>(double x)
{
return std::asin(x);
};
template <typename T>
CK_TILE_HOST T asinh(T x)
{
return type_convert<T>(std::asinhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float asinh<float>(float x)
{
return std::asinhf(x);
};
template <>
CK_TILE_HOST double asinh<double>(double x)
{
return std::asinh(x);
};
template <typename T>
CK_TILE_HOST T cos(T x)
{
return type_convert<T>(std::cosf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float cos<float>(float x)
{
return std::cosf(x);
};
template <>
CK_TILE_HOST double cos<double>(double x)
{
return std::cos(x);
};
template <typename T>
CK_TILE_HOST T acosh(T x)
{
return type_convert<T>(std::acoshf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float acosh<float>(float x)
{
return std::acoshf(x);
};
template <>
CK_TILE_HOST double acosh<double>(double x)
{
return std::acosh(x);
};
template <typename T>
CK_TILE_HOST T tan(T x)
{
return type_convert<T>(std::tanf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float tan<float>(float x)
{
return std::tanf(x);
};
template <>
CK_TILE_HOST double tan<double>(double x)
{
return std::tan(x);
};
template <typename T>
CK_TILE_HOST T atanh(T x)
{
return type_convert<T>(std::atanhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float atanh<float>(float x)
{
return std::atanhf(x);
};
template <>
CK_TILE_HOST double atanh<double>(double x)
{
return std::atanh(x);
};
template <typename T>
CK_TILE_HOST T sinh(T x)
{
return type_convert<T>(std::sinhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float sinh<float>(float x)
{
return std::sinhf(x);
};
template <>
CK_TILE_HOST double sinh<double>(double x)
{
return std::sinh(x);
};
template <typename T>
CK_TILE_HOST T ceil(T x)
{
return type_convert<T>(std::ceilf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float ceil<float>(float x)
{
return std::ceilf(x);
};
template <>
CK_TILE_HOST double ceil<double>(double x)
{
return std::ceil(x);
};
template <typename T>
CK_TILE_HOST T cosh(T x)
{
return type_convert<T>(std::coshf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float cosh<float>(float x)
{
return std::coshf(x);
};
template <>
CK_TILE_HOST double cosh<double>(double x)
{
return std::cosh(x);
};
template <typename T>
CK_TILE_HOST T floor(T x)
{
return type_convert<T>(std::floorf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float floor<float>(float x)
{
return std::floorf(x);
};
template <>
CK_TILE_HOST double floor<double>(double x)
{
return std::floor(x);
};
template <typename T>
CK_TILE_HOST T rcp(T x)
{
return type_convert<T>(1.f / type_convert<float>(x));
};
template <typename T>
CK_TILE_HOST T exp(T x)
{
return type_convert<T>(std::expf(type_convert<float>(x)));
}
template <>
CK_TILE_HOST float exp<float>(float x)
{
return std::expf(x);
}
template <>
CK_TILE_HOST double exp<double>(double x)
{
return std::exp(x);
}
template <typename T>
CK_TILE_HOST T log(T x)
{
return type_convert<T>(std::logf(type_convert<float>(x)));
}
template <>
CK_TILE_HOST float log<float>(float x)
{
return std::logf(x);
}
template <>
CK_TILE_HOST double log<double>(double x)
{
return std::log(x);
}
template <typename T>
CK_TILE_HOST T pow(T x, T gamma)
{
return type_convert<T>(std::powf(type_convert<float>(x), type_convert<float>(gamma)));
}
template <>
CK_TILE_HOST float pow<float>(float x, float gamma)
{
return std::powf(x, gamma);
}
template <>
CK_TILE_HOST double pow<double>(double x, double gamma)
{
return std::pow(x, gamma);
}
template <typename T>
CK_TILE_HOST T expm1(T x)
{
return type_convert<T>(std::expm1f(type_convert<float>(x)));
}
template <>
CK_TILE_HOST float expm1<float>(float x)
{
return std::expm1f(x);
}
template <>
CK_TILE_HOST double expm1<double>(double x)
{
return std::expm1(x);
}
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
CK_TILE_DEVICE float abs(float x)
{
union
{
float f32;
uint32_t u32;
} y;
y.f32 = x;
y.u32 = y.u32 & 0x7fffffff;
return y.f32;
};
CK_TILE_DEVICE double abs(double x) { return ::abs(x); };
CK_TILE_DEVICE int8_t abs(int8_t x)
{
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn;
};
CK_TILE_DEVICE int32_t abs(int32_t x)
{
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE int4_t abs(int4_t x)
{
int4_t sgn = x >> (4 - 1);
return (x ^ sgn) - sgn;
};
#endif
CK_TILE_DEVICE fp16_t abs(fp16_t x)
{
uint16_t xx = bit_cast<uint16_t>(x);
uint16_t abs_xx = xx & 0x7fff;
fp16_t abs_x = bit_cast<fp16_t>(abs_xx);
return abs_x;
};
CK_TILE_DEVICE bool isnan(float x) { return ::isnan(x); };
CK_TILE_DEVICE bool isnan(double x) { return ::isnan(x); };
CK_TILE_DEVICE bool isnan(int8_t x)
{
(void)x;
return false;
};
CK_TILE_DEVICE bool isnan(int32_t x)
{
(void)x;
return false;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE bool isnan(int4_t x)
{
(void)x;
return false;
};
#endif
CK_TILE_DEVICE bool isnan(fp16_t x)
{
uint16_t xx = bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
};
CK_TILE_DEVICE fp16_t sqrt(fp16_t x)
{
return static_cast<fp16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
CK_TILE_DEVICE float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
CK_TILE_DEVICE double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
template <typename T>
CK_TILE_DEVICE T tanh(T x)
{
return type_convert<T>(::tanhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float tanh<float>(float x)
{
return ::tanhf(x);
};
template <>
CK_TILE_DEVICE double tanh<double>(double x)
{
return ::tanh(x);
};
template <typename T>
CK_TILE_DEVICE T acos(T x)
{
return type_convert<T>(::acosf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float acos<float>(float x)
{
return ::acosf(x);
};
template <>
CK_TILE_DEVICE double acos<double>(double x)
{
return ::acos(x);
};
template <typename T>
CK_TILE_DEVICE T neg(T x)
{
return type_convert<T>(-(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float neg<float>(float x)
{
return -x;
};
template <>
CK_TILE_DEVICE double neg<double>(double x)
{
return -x;
};
template <>
CK_TILE_DEVICE int32_t neg<int32_t>(int32_t x)
{
return -x;
};
template <>
CK_TILE_DEVICE int8_t neg<int8_t>(int8_t x)
{
return -x;
};
template <>
CK_TILE_DEVICE fp16_t neg<fp16_t>(fp16_t x)
{
return __hneg(x);
};
template <typename T>
CK_TILE_DEVICE T atan(T x)
{
return type_convert<T>(::atanf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float atan<float>(float x)
{
return ::atanf(x);
};
template <>
CK_TILE_DEVICE double atan<double>(double x)
{
return ::atan(x);
};
template <typename T>
CK_TILE_DEVICE T sin(T x)
{
return type_convert<T>(::sinf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float sin<float>(float x)
{
return ::sinf(x);
};
template <>
CK_TILE_DEVICE double sin<double>(double x)
{
return ::sin(x);
};
template <>
CK_TILE_DEVICE fp16_t sin<fp16_t>(fp16_t x)
{
return ::hsin(x);
};
template <typename T>
CK_TILE_DEVICE T asin(T x)
{
return type_convert<T>(::asinf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float asin<float>(float x)
{
return ::asinf(x);
};
template <>
CK_TILE_DEVICE double asin<double>(double x)
{
return ::asin(x);
};
template <typename T>
CK_TILE_DEVICE T asinh(T x)
{
return type_convert<T>(::asinhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float asinh<float>(float x)
{
return ::asinhf(x);
};
template <>
CK_TILE_DEVICE double asinh<double>(double x)
{
return ::asinh(x);
};
template <typename T>
CK_TILE_DEVICE T acosh(T x)
{
return type_convert<T>(::acoshf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float acosh<float>(float x)
{
return ::acoshf(x);
};
template <>
CK_TILE_DEVICE double acosh<double>(double x)
{
return ::acosh(x);
};
template <typename T>
CK_TILE_DEVICE T tan(T x)
{
return type_convert<T>(::tanf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float tan<float>(float x)
{
return ::tanf(x);
};
template <>
CK_TILE_DEVICE double tan<double>(double x)
{
return ::tan(x);
};
template <typename T>
CK_TILE_DEVICE T atanh(T x)
{
return type_convert<T>(::atanhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float atanh<float>(float x)
{
return ::atanhf(x);
};
template <>
CK_TILE_DEVICE double atanh<double>(double x)
{
return ::atanh(x);
};
template <typename T>
CK_TILE_DEVICE T sinh(T x)
{
return type_convert<T>(::sinhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float sinh<float>(float x)
{
return ::sinhf(x);
};
template <>
CK_TILE_DEVICE double sinh<double>(double x)
{
return ::sinh(x);
};
template <typename T>
CK_TILE_DEVICE T ceil(T x)
{
return type_convert<T>(::ceilf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float ceil<float>(float x)
{
return ::ceilf(x);
};
template <>
CK_TILE_DEVICE double ceil<double>(double x)
{
return ::ceil(x);
};
template <>
CK_TILE_DEVICE fp16_t ceil<fp16_t>(fp16_t x)
{
return ::hceil(x);
};
template <typename T>
CK_TILE_DEVICE T cosh(T x)
{
return type_convert<T>(::coshf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float cosh<float>(float x)
{
return ::coshf(x);
};
template <>
CK_TILE_DEVICE double cosh<double>(double x)
{
return ::cosh(x);
};
template <typename T>
CK_TILE_DEVICE T floor(T x)
{
return type_convert<T>(::floorf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float floor<float>(float x)
{
return ::floorf(x);
};
template <>
CK_TILE_DEVICE double floor<double>(double x)
{
return ::floor(x);
};
template <>
CK_TILE_DEVICE fp16_t floor<fp16_t>(fp16_t x)
{
return ::hfloor(x);
};
template <typename T>
CK_TILE_DEVICE T rcp(T x)
{
#if !CK_TILE_WORKAROUND_SWDEV_383542
return __frcp_rn(x);
#else
// return __ocml_native_recip_f32(x);
return __builtin_amdgcn_rcpf(x);
#endif
};
template <typename T>
CK_TILE_DEVICE T exp(T x)
{
return type_convert<T>(__ocml_exp_f32(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE fp16_t exp<fp16_t>(fp16_t x)
{
return hexp(x);
};
template <>
CK_TILE_DEVICE float exp<float>(float x)
{
return __ocml_exp_f32(x);
};
template <>
CK_TILE_DEVICE double exp<double>(double x)
{
return exp(x);
};
template <typename T>
CK_TILE_DEVICE T log(T x)
{
return type_convert<T>(__logf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE fp16_t log<fp16_t>(fp16_t x)
{
return hlog(x);
};
template <>
CK_TILE_DEVICE float log<float>(float x)
{
return __logf(x);
};
template <>
CK_TILE_DEVICE double log<double>(double x)
{
return log(x);
};
template <typename T>
CK_TILE_DEVICE T pow(T x, T gamma)
{
return type_convert<T>(powf(type_convert<float>(x), type_convert<float>(gamma)));
};
template <>
CK_TILE_DEVICE float pow<float>(float x, float gamma)
{
return powf(x, gamma);
};
template <>
CK_TILE_DEVICE double pow<double>(double x, double gamma)
{
return pow(x, gamma);
};
template <typename T>
CK_TILE_DEVICE T expm1(T x)
{
return type_convert<T>(expm1f(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float expm1<float>(float x)
{
return expm1f(x);
};
template <>
CK_TILE_DEVICE double expm1<double>(double x)
{
return expm1(x);
};
} // namespace ck_tile

View File

@@ -91,8 +91,10 @@ struct buffer_view<address_space_enum::generic,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
CK_TILE_DEVICE constexpr auto get(index_t i,
index_t linear_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -107,11 +109,11 @@ struct buffer_view<address_space_enum::generic,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp;
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
__builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
return tmp;
#else
return *c_style_pointer_cast<const X*>(&p_data_[i]);
return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
#endif
}
else
@@ -134,17 +136,17 @@ struct buffer_view<address_space_enum::generic,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, is_valid_element, x);
this->template set<X>(i, linear_offset, is_valid_element, x);
}
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, is_valid_element);
this->template set<X>(i, is_valid_element, x + tmp);
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
}
}
@@ -154,7 +156,7 @@ struct buffer_view<address_space_enum::generic,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -169,9 +171,9 @@ struct buffer_view<address_space_enum::generic,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp = x;
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
__builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
*c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
#endif
}
}
@@ -276,8 +278,10 @@ struct buffer_view<address_space_enum::global,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
CK_TILE_DEVICE constexpr auto get(index_t i,
index_t linear_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -303,7 +307,7 @@ struct buffer_view<address_space_enum::global,
t_per_x,
Coherence,
oob_conditional_check>(
p_data_, i, is_valid_element, buffer_size_);
p_data_, i + linear_offset, is_valid_element, buffer_size_);
}
else
{
@@ -311,8 +315,11 @@ struct buffer_view<address_space_enum::global,
remove_cvref_t<T>,
t_per_x,
Coherence,
oob_conditional_check>(
p_data_, i, is_valid_element, buffer_size_, invalid_element_value_);
oob_conditional_check>(p_data_,
i + linear_offset,
is_valid_element,
buffer_size_,
invalid_element_value_);
}
}
else
@@ -322,11 +329,11 @@ struct buffer_view<address_space_enum::global,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp;
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
__builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
return tmp;
#else
return *c_style_pointer_cast<const X*>(&p_data_[i]);
return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
#endif
}
else
@@ -352,7 +359,8 @@ struct buffer_view<address_space_enum::global,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
index_t i,
index_t v_offset,
index_t i_offset,
bool is_valid_element,
bool_constant<pre_nop> = {}) const
{
@@ -366,7 +374,38 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
dst, cached_buf_res_, v_offset, i_offset, is_valid_element, bool_constant<pre_nop>{});
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto async_get(CK_TILE_LDS_ADDR remove_cvref_t<T>* smem,
index_t i,
index_t linear_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
// X is vector of T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
smem,
cached_buf_res_,
i,
linear_offset,
is_valid_element,
bool_constant<oob_conditional_check>{});
}
// i is offset of T, not X. i should be aligned to X
@@ -378,6 +417,7 @@ struct buffer_view<address_space_enum::global,
bool>::type = false>
CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t<T>* smem,
index_t i,
index_t linear_offset,
bool /*is_valid_element*/,
bool_constant<pre_nop> = {}) const
{
@@ -391,7 +431,7 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
smem, cached_buf_res_, i, bool_constant<pre_nop>{});
smem, cached_buf_res_, i, linear_offset, bool_constant<pre_nop>{});
}
// i is offset of T, not X. i should be aligned to X
@@ -401,25 +441,25 @@ struct buffer_view<address_space_enum::global,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, is_valid_element, x);
this->template set<X>(i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_add)
{
this->template atomic_add<X>(i, is_valid_element, x);
this->template atomic_add<X>(i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_max)
{
this->template atomic_max<X>(i, is_valid_element, x);
this->template atomic_max<X>(i, linear_offset, is_valid_element, x);
}
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, is_valid_element);
this->template set<X>(i, is_valid_element, x + tmp);
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
// tmp += x;
// this->template set<X>(i, is_valid_element, tmp);
}
@@ -432,7 +472,7 @@ struct buffer_view<address_space_enum::global,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -453,7 +493,7 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cvref_t<T>, t_per_x, Coherence>(
x, p_data_, i, is_valid_element, buffer_size_);
x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
}
else
{
@@ -462,9 +502,9 @@ struct buffer_view<address_space_enum::global,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp = x;
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
__builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
*c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
#endif
}
}
@@ -477,7 +517,7 @@ struct buffer_view<address_space_enum::global,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x)
CK_TILE_DEVICE void set_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -489,7 +529,7 @@ struct buffer_view<address_space_enum::global,
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
x, p_data_, i, is_valid_element, buffer_size_);
x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
}
template <typename X,
@@ -497,7 +537,8 @@ struct buffer_view<address_space_enum::global,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x)
CK_TILE_DEVICE void
atomic_add(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
@@ -532,13 +573,13 @@ struct buffer_view<address_space_enum::global,
if constexpr(use_amd_buffer_addressing)
{
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, buffer_size_);
x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
}
else
{
if(is_valid_element)
{
atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
}
}
}
@@ -548,7 +589,8 @@ struct buffer_view<address_space_enum::global,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x)
CK_TILE_DEVICE void
atomic_max(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -572,11 +614,11 @@ struct buffer_view<address_space_enum::global,
if constexpr(use_amd_buffer_addressing)
{
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, buffer_size_);
x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
}
else if(is_valid_element)
{
atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
}
}
@@ -668,8 +710,10 @@ struct buffer_view<address_space_enum::lds,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
CK_TILE_DEVICE constexpr auto get(index_t i,
index_t linear_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -684,14 +728,14 @@ struct buffer_view<address_space_enum::lds,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp;
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
__builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
return tmp;
#else
using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
scalar_per_t_vector * scalar_per_x_vector>;
// using buf_t = ushort __attribute__((ext_vector_type(8)));
auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i]);
auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i + linear_offset]);
return bit_cast<X>(rtn);
#endif
}
@@ -708,6 +752,23 @@ struct buffer_view<address_space_enum::lds,
}
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
index_t v_offset,
index_t i_offset,
bool /*is_valid_element*/,
bool_constant<pre_nop> = {}) const
{
smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
@@ -715,17 +776,17 @@ struct buffer_view<address_space_enum::lds,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, is_valid_element, x);
this->template set<X>(i, linear_offset, is_valid_element, x);
}
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, is_valid_element);
this->template set<X>(i, is_valid_element, x + tmp);
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
}
}
@@ -735,7 +796,7 @@ struct buffer_view<address_space_enum::lds,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -751,6 +812,7 @@ struct buffer_view<address_space_enum::lds,
bool constexpr workaround_int8_ds_write_issue = false;
#endif
i += linear_offset; // simplicity
if constexpr(std::is_same<typename vector_traits<remove_cvref_t<T>>::scalar_type,
int8_t>::value &&
workaround_int8_ds_write_issue)
@@ -952,8 +1014,10 @@ struct buffer_view<address_space_enum::vgpr,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
CK_TILE_DEVICE constexpr auto get(index_t i,
index_t /*linear_offset*/,
bool is_valid_element,
bool_constant<oob_conditional_check> = {}) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -995,17 +1059,17 @@ struct buffer_view<address_space_enum::vgpr,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, is_valid_element, x);
this->template set<X>(i, linear_offset, is_valid_element, x);
}
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, is_valid_element);
this->template set<X>(i, is_valid_element, x + tmp);
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
}
}
@@ -1015,7 +1079,7 @@ struct buffer_view<address_space_enum::vgpr,
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
@@ -1030,9 +1094,9 @@ struct buffer_view<address_space_enum::vgpr,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X tmp = x;
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
__builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
*c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
#endif
}
}

View File

@@ -12,6 +12,7 @@
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
@@ -28,7 +29,21 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT
NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(bool_constant<oob_conditional_check>{});
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
}
template <typename T,
@@ -46,7 +61,27 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
tile_window.load_raw(
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename T,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
@@ -66,7 +101,26 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)

View File

@@ -109,7 +109,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
// get input vectors
static_for<0, num_vec_in, 1>{}([&](auto i) {
constexpr auto idx_y_in = generate_array(
constexpr auto idx_y_in = generate_tuple(
[&](auto ii) {
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
},

View File

@@ -10,6 +10,7 @@
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
@@ -72,7 +73,7 @@ store_tile(tile_window_with_static_distribution<BottomTensorView_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
tile_window.store(dstr_tensor);
tile_window.store(dstr_tensor, number<-1>{});
}
template <typename BottomTensorView_,
@@ -87,7 +88,33 @@ store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
tile_window.store_raw(dstr_tensor);
tile_window.store_raw(dstr_tensor, number<-1>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
typename DataType_>
CK_TILE_DEVICE void store_tile(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
tile_window.store(dstr_tensor, number<-1>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
typename DataType_>
CK_TILE_DEVICE void store_tile_raw(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
tile_window.store_raw(dstr_tensor, number<-1>{});
}
} // namespace ck_tile

View File

@@ -16,6 +16,24 @@
namespace ck_tile {
/*
* tensor_view
* abstract the underneath memory buffer(global, LDS, etc...)
* and provide a unified get/set function for access
*
* For addressing into the buffer we use 2 variable to control:
* coord : ND tensor coordinate, will calculate the actual offset inside
* linear_offset : 1D offset, will be used in the immediate field of
* the buffer instruction to help reduce register usage
*
* User can use either of the field, or both to indexing into the tensor
*
* We usually provide 2 set of API for buffer get/set, e.g.
* get_vectorized_elements()/get_vectorized_elements_raw()
* the former usually will call intrinsic or normal C function, the later
* usually will call inline-asm function
*
*/
template <typename BufferView_,
typename TensorDesc_,
memory_operation_enum DstInMemOp_ = memory_operation_enum::set>
@@ -49,22 +67,6 @@ struct tensor_view
CK_TILE_HOST_DEVICE constexpr auto& get_buffer_view() { return buf_; }
#if 0
CK_TILE_HOST_DEVICE constexpr DataType get_element(const TensorCoord& coord) const
{
return buf_.template get<DataType>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
}
CK_TILE_HOST_DEVICE constexpr void set_element(const TensorCoord& coord, const DataType& x)
{
buf_.template set<DataType>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
#endif
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
@@ -75,14 +77,34 @@ struct tensor_view
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
get_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template get<X>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<oob_conditional_check>{});
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
get_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element, // flag
bool_constant<oob_conditional_check> = {}) const
{
return buf_.template get<X>(coord.get_offset(),
linear_offset,
is_valid_element,
bool_constant<oob_conditional_check>{});
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
@@ -94,12 +116,90 @@ struct tensor_view
bool>::type = false>
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
index_t linear_offset,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
dst,
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
}
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
dst, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset) const
{
return buf_.template async_get<X>(
smem,
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<oob_conditional_check>{});
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element) const
{
return buf_.template async_get<X>(smem,
coord.get_offset(),
linear_offset,
is_valid_element,
bool_constant<oob_conditional_check>{});
}
template <typename X,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset,
bool_constant<pre_nop> = {}) const
{
return buf_.template async_get_raw<X>(
smem,
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
}
@@ -110,11 +210,15 @@ struct tensor_view
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
bool_constant<pre_nop> = {}) const
{
return buf_.template async_get_raw<X>(
smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
smem, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
}
// X is vector of DataType.
@@ -125,11 +229,15 @@ struct tensor_view
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
CK_TILE_HOST_DEVICE constexpr void
set_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template set<X, oob_conditional_check>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
@@ -140,15 +248,53 @@ struct tensor_view
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
CK_TILE_HOST_DEVICE constexpr void
set_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template set<X, oob_conditional_check>(
coord.get_offset(), linear_offset, is_valid_element, x);
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
set_vectorized_elements_raw(const TensorCoord& coord,
index_t linear_offset,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template set_raw<X, oob_conditional_check>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
set_vectorized_elements_raw(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template set_raw<X, oob_conditional_check>(
coord.get_offset(), linear_offset, is_valid_element, x);
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
@@ -157,15 +303,36 @@ struct tensor_view
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
CK_TILE_HOST_DEVICE constexpr void
update_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template update<DstInMemOp, X, oob_conditional_check>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
update_vectorized_elements(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {})
{
buf_.template update<DstInMemOp, X, oob_conditional_check>(
coord.get_offset(), linear_offset, is_valid_element, x);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_view{");

View File

@@ -18,6 +18,8 @@
namespace ck_tile {
// Note: this tile window do not support single issue
// you need to use tile_window_linear structure for this purpose
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
@@ -41,6 +43,7 @@ struct tile_window_with_static_distribution
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static_assert(NumCoord == 1);
// TODO: check WindowLengths and StaticTileDistribution are consistent
@@ -189,7 +192,8 @@ struct tile_window_with_static_distribution
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
@@ -222,10 +226,11 @@ struct tile_window_with_static_distribution
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
WindowAdaptorCoord& window_adaptor_thread_coord,
BottomTensorCoord& bottom_tensor_thread_coord,
const AdaptorTopIndex& idx_diff_adaptor_top) const
const ATopIndex& idx_diff_adaptor_top) const
{
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
@@ -279,10 +284,11 @@ struct tile_window_with_static_distribution
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
}
CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; }
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; }
template <bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
@@ -308,11 +314,11 @@ struct tile_window_with_static_distribution
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
#if 1
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
@@ -338,8 +344,9 @@ struct tile_window_with_static_distribution
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
@@ -350,8 +357,12 @@ struct tile_window_with_static_distribution
return dst_tensor;
}
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
template <typename DstTile,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
@@ -397,6 +408,7 @@ struct tile_window_with_static_distribution
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
bottom_tensor_thread_coord,
0 /**/,
bool_constant<oob_conditional_check>{},
pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
@@ -409,23 +421,24 @@ struct tile_window_with_static_distribution
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm volatile("; this inline asm is workaround to prevent compiler from using too much "
"scratch memory" ::);
#endif
}
// TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
@@ -467,7 +480,7 @@ struct tile_window_with_static_distribution
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// TODO: use structure binding (to be captured later) if compiled in C++20
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
@@ -482,15 +495,16 @@ struct tile_window_with_static_distribution
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
smem, bottom_tensor_thread_coord, pre_nop_);
smem, bottom_tensor_thread_coord, 0, pre_nop_);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
@@ -501,8 +515,81 @@ struct tile_window_with_static_distribution
});
}
template <bool oob_conditional_check = true>
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
// check?)
constexpr index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{}));
constexpr index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
size_per_buf;
constexpr index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR LdsDataType* smem =
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// read from bottom tensor
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
smem += size_per_issue; // Note we manually increase the per-issue offset
}
});
});
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
@@ -515,7 +602,6 @@ struct tile_window_with_static_distribution
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
@@ -530,7 +616,7 @@ struct tile_window_with_static_distribution
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
@@ -548,15 +634,19 @@ struct tile_window_with_static_distribution
// write into bottom tensor
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
@@ -565,8 +655,9 @@ struct tile_window_with_static_distribution
});
}
CK_TILE_DEVICE void
store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
template <index_t i_access_unsupport_ = -1>
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {}) const
{
using Traits = load_store_traits;
@@ -591,7 +682,7 @@ struct tile_window_with_static_distribution
// read from distributed tensor
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
@@ -606,15 +697,16 @@ struct tile_window_with_static_distribution
// write into bottom tensor
get_bottom_tensor_view()
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
bottom_tensor_thread_coord, vec_value);
bottom_tensor_thread_coord, 0, vec_value);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
@@ -623,8 +715,9 @@ struct tile_window_with_static_distribution
});
}
template <bool oob_conditional_check = true>
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
@@ -650,7 +743,7 @@ struct tile_window_with_static_distribution
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
@@ -666,15 +759,19 @@ struct tile_window_with_static_distribution
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys =
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
@@ -746,7 +843,8 @@ struct tile_window_with_static_distribution
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
@@ -798,6 +896,27 @@ make_tile_window(const TensorView_& tensor_view,
tensor_view, window_lengths, origin, tile_distribution};
}
// this version can't be called in a constexpr context
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
index_t NumCoord = 1>
CK_TILE_DEVICE auto
make_tile_window_raw(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
number<NumCoord> = {})
{
auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
remove_cvref_t<WindowLengths_>,
remove_cvref_t<StaticTileDistribution_>,
NumCoord>{
tensor_view, window_lengths, origin, tile_distribution};
w.init_raw();
return w;
}
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
@@ -922,6 +1041,19 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
tile_distribution);
}
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE constexpr auto
make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution)
{
auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution);
w.init_raw();
return w;
}
template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void move_tile_window(
tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,

File diff suppressed because it is too large Load Diff

View File

@@ -59,8 +59,16 @@ struct magic_division32_bit_range
CK_TILE_DEVICE static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift;
if(__builtin_is_constant_evaluated())
{
uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
return (tmp + dividend) >> shift;
}
else
{
uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift;
}
}
CK_TILE_HOST static constexpr uint32_t
@@ -77,9 +85,18 @@ struct magic_division32_bit_range
CK_TILE_DEVICE static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
if(__builtin_is_constant_evaluated())
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
return (tmp + dividend_u32) >> shift;
}
else
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
}
}
CK_TILE_HOST static constexpr int32_t

View File

@@ -24,5 +24,6 @@
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"

View File

@@ -10,6 +10,7 @@
#include <random>
#include <type_traits>
#include <utility>
#include <unordered_set>
#include "ck_tile/core.hpp"
@@ -41,6 +42,73 @@ struct FillUniformDistribution
}
};
namespace impl {
// clang-format off
template<index_t bytes> struct RawIntegerType_ {};
template<> struct RawIntegerType_<1> { using type = uint8_t;};
template<> struct RawIntegerType_<2> { using type = uint16_t;};
template<> struct RawIntegerType_<4> { using type = uint32_t;};
template<> struct RawIntegerType_<8> { using type = uint64_t;};
// clang-format on
template <typename T>
using RawIntegerType = typename RawIntegerType_<sizeof(T)>::type;
} // namespace impl
// Note: this struct will have no const-ness will generate random
template <typename T>
struct FillUniformDistribution_Unique
{
float a_{-5.f};
float b_{5.f};
std::optional<uint32_t> seed_{11939};
std::mt19937 gen_{};
std::unordered_set<impl::RawIntegerType<T>> set_{};
FillUniformDistribution_Unique(float a = -5.f,
float b = 5.f,
std::optional<uint32_t> seed = {11939})
: a_(a),
b_(b),
seed_(seed),
gen_{seed_.has_value() ? *seed_ : std::random_device{}()},
set_{}
{
}
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last)
{
std::mt19937& gen = gen_;
std::uniform_real_distribution<float> dis(a_, b_);
auto& set = set_;
std::generate(first, last, [&dis, &gen, &set]() {
T v = static_cast<T>(0);
do
{
v = ck_tile::type_convert<T>(dis(gen));
} while(set.count(bit_cast<impl::RawIntegerType<T>>(v)) == 1);
set.insert(bit_cast<impl::RawIntegerType<T>>(v));
return v;
});
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range)
-> std::void_t<decltype(std::declval<FillUniformDistribution_Unique&>()(
std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
}
void clear() { set_.clear(); }
};
template <typename T>
struct FillNormalDistribution
{

View File

@@ -11,6 +11,7 @@
#include <thread>
#include <utility>
#include <vector>
#include <functional>
#include "ck_tile/core.hpp"
#include "ck_tile/host/ranges.hpp"
@@ -545,6 +546,28 @@ struct HostTensor
typename Data::size_type size() const { return mData.size(); }
// return a slice of this tensor
// for simplicity we just copy the data and return a new tensor
auto slice(std::vector<size_t> s_begin, std::vector<size_t> s_end) const
{
assert(s_begin.size() == s_end.size());
assert(s_begin.size() == get_num_of_dimension());
std::vector<size_t> s_len(s_begin.size());
std::transform(
s_end.begin(), s_end.end(), s_begin.begin(), s_len.begin(), std::minus<size_t>{});
HostTensor<T> sliced_tensor(s_len);
sliced_tensor.ForEach([&](auto& self, auto idx) {
std::vector<size_t> src_idx(idx.size());
std::transform(
idx.begin(), idx.end(), s_begin.begin(), src_idx.begin(), std::plus<size_t>{});
self(idx) = operator()(src_idx);
});
return sliced_tensor;
}
template <typename U = T>
auto AsSpan() const
{

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -9,43 +9,81 @@
namespace ck_tile {
template <typename ADataType, typename AccDataType, typename BDataType>
CK_TILE_HOST void reference_softmax(const HostTensor<ADataType>& a_m_n,
HostTensor<BDataType>& b_m_n)
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST void
reference_softmax(const HostTensor<InputType>& x, HostTensor<OutputType>& y, index_t dim = -1)
{
auto f = [&](auto m) {
const int N = a_m_n.mDesc.get_lengths()[1];
index_t rank = x.get_num_of_dimension();
assert(rank == y.get_num_of_dimension());
assert(dim == -1 || dim < rank);
AccDataType v_max = ck_tile::numeric<ADataType>::Lowest();
index_t target_dim = dim == -1 ? (rank - 1) : dim;
index_t softmax_len = x.get_length(target_dim);
index_t n_parallel = x.get_element_size() / softmax_len;
auto x_len = x.get_lengths();
// max
for(int n = 0; n < N; ++n)
auto f = [&](auto i_element) {
std::vector<size_t> coord = [&]() {
std::vector<size_t> t_(rank, 0);
size_t r = i_element;
for(index_t i = rank - 1; i >= 0; i--)
{
if(i == target_dim)
continue;
t_[i] = r % x_len[i];
r = r / x_len[i];
}
return t_;
}();
ComputeType v_max = -ck_tile::numeric<ComputeType>::infinity();
// compute max
for(auto idx = 0; idx < softmax_len; idx++)
{
const ADataType v_a = a_m_n(m, n);
v_max = v_max < v_a ? v_a : v_max;
auto c_ = coord;
c_[target_dim] = idx;
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
v_max = v_max < v_x ? v_x : v_max;
}
AccDataType v_exp_sum = 0;
ComputeType v_exp_sum = static_cast<ComputeType>(0);
// sum
for(int n = 0; n < N; ++n)
for(auto idx = 0; idx < softmax_len; idx++)
{
const ADataType v_a = a_m_n(m, n);
auto c_ = coord;
c_[target_dim] = idx;
v_exp_sum += ck_tile::exp(v_a - v_max);
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
v_exp_sum += ck_tile::exp(v_x - v_max);
}
// elementwise
for(int n = 0; n < N; ++n)
for(auto idx = 0; idx < softmax_len; idx++)
{
const ADataType v_a = a_m_n(m, n);
auto c_ = coord;
c_[target_dim] = idx;
b_m_n(m, n) = ck_tile::exp(v_a - v_max) / v_exp_sum;
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
auto out = ck_tile::exp(v_x - v_max) / v_exp_sum;
y(c_) = ck_tile::type_convert<OutputType>(out);
}
};
make_ParallelTensorFunctor(f,
b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST auto reference_softmax(const HostTensor<InputType>& x, index_t dim = -1)
{
HostTensor<OutputType> y(x.get_lengths(), x.get_strides());
reference_softmax<InputType, ComputeType, OutputType>(x, y, dim);
return y;
}
} // namespace ck_tile

View File

@@ -0,0 +1,124 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
#include <numeric>
#include <functional>
#include <utility>
#include <algorithm>
namespace ck_tile {
/*
similiar to torch.topk()
x (Tensor) the input tensor.
k (int) the k in “top-k”
dim (int, optional) the dimension to sort along
largest (bool, optional) largest or smallest elements
sorted (bool, optional) elements in sorted order or not
output:
y_values
y_indices
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TopKImpl.h
*/
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
HostTensor<DataType>& y_values,
HostTensor<IndexType>& y_indices,
index_t k,
index_t dim = -1,
bool largest = true,
bool sorted = true)
{
// rank must be the same
index_t rank = x.get_num_of_dimension();
assert(rank == y_values.get_num_of_dimension());
assert(rank == y_indices.get_num_of_dimension());
assert(dim == -1 || dim < rank);
index_t topk_dim = dim == -1 ? (rank - 1) : dim;
index_t topk_src_len = x.get_length(topk_dim);
auto x_len = x.get_lengths();
assert(k <= topk_src_len);
assert(k == y_values.get_length(topk_dim) && k == y_indices.get_length(topk_dim));
index_t n_parallel = x.get_element_size() / topk_src_len;
// clang-format off
auto f = [&](auto i_element) {
std::vector<size_t> topk_coord = [&](){
std::vector<size_t> t_(rank, 0);
size_t r = i_element;
for(index_t i = rank - 1; i >= 0; i--) {
if(i == topk_dim) continue; // topk dim should be zero
t_[i] = r % x_len[i]; r = r / x_len[i];
}
return t_;
}();
using elem_t = std::pair<DataType, IndexType>;
std::vector<elem_t> q = [&](){
std::vector<elem_t> t_(topk_src_len);
for(index_t i = 0; i < topk_src_len; i++) {
auto c_ = topk_coord; c_[topk_dim] = i;
t_[i].first = x(c_); t_[i].second = i;
}
return t_;
}();
// run topk
if(largest) {
std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
if(sorted) {
std::sort(q.begin(), q.begin() + k - 1,
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
}
} else {
std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
if(sorted) {
std::sort(q.begin(), q.begin() + k - 1,
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
}
}
// write out
for(index_t i = 0; i < k; i++) {
auto c_ = topk_coord; c_[topk_dim] = i;
y_values(c_) = q[i].first; y_indices(c_) = q[i].second;
}
};
// clang-format on
make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}
// TODO: if using this method, the return tensor would be dense(no stride)
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST auto reference_topk(const HostTensor<DataType>& x,
index_t k,
index_t dim = -1,
bool largest = true,
bool sorted = true)
{
auto lens = x.get_lengths();
index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim;
assert(target_dim < lens.size());
assert(k <= lens[target_dim]);
lens[target_dim] = k;
HostTensor<DataType> y_values(lens);
HostTensor<IndexType> y_indices(lens);
reference_topk<DataType, IndexType>(x, y_values, y_indices, k, dim, largest, sorted);
return ck_tile::make_tuple(y_values, y_indices);
}
} // namespace ck_tile

View File

@@ -0,0 +1,7 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

File diff suppressed because it is too large Load Diff

View File

@@ -334,7 +334,7 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
@@ -359,7 +359,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_access());
async_load_fence(k_dram_window.get_num_of_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,

View File

@@ -4,9 +4,14 @@
#pragma once
#include "ck_tile/core.hpp"
#include <tuple>
namespace ck_tile {
/*
* TODO: block_tile_reduce_sync() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
@@ -104,6 +109,65 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
});
}
/*
* this version is faster, using xor to do reduce, no need broadcast anymore
* TODO: the limitation is to-be-reduced P dim can only mapping to one R dim?
*/
template <typename AccDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
const ReduceFunc& reduce_func)
{
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_lane = NDimP - 1;
constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
// loop over thread data
static_for<0, thread_buf_size, 1>{}([&](auto i) {
auto v_local = acc_tensor.get_thread_buffer()[i];
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for<0, NDimR, 1>{}([&](auto idim_r) {
// FIXME: nasty to use does_p_own_r_
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
constexpr index_t lid_over_rid_derivative =
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
static_assert(is_power_of_two_integer(r_length),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = integer_log2_floor(r_length);
// reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) {
// xor
index_t src_lane =
__lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);
// pull data from remote lane
const auto v_remote = warp_shuffle(v_local, src_lane);
// reduce
v_local = reduce_func(v_local, v_remote);
});
}
});
acc_tensor.get_thread_buffer()(i) = v_local;
});
}
// FIXME: this is for 2D to 1D reduce only, need to support n-D
template <typename AccDistributedTensor_,
typename InDistributedTensor_,
@@ -175,6 +239,10 @@ CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
#endif
}
/*
* TODO: block_tile_reduce() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
template <typename AccDataType_,
typename InDistributedTensor_,
index_t... InReduceDims,
@@ -208,4 +276,106 @@ CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
return acc_tensor;
}
// this version only support 2D->1D reduce (reduce-dim=seq<0, 1>)
// this version only support in/acc/out datatypes are the same
// this version will call thread/warp+sync in one function call
//
template <typename InDistributedTensor_>
struct BlockReduce2D
{
using InDistributedTensor = remove_cvref_t<InDistributedTensor_>;
using InDataType = typename InDistributedTensor::DataType;
CK_TILE_HOST_DEVICE BlockReduce2D(const InDistributedTensor& t_, const InDataType& reduce_init_)
: t(t_), reduce_init(reduce_init_)
{
}
CK_TILE_HOST_DEVICE constexpr auto MakeDstBlockTile() const
{
using ReduceDim = sequence<1>; // hard coded
constexpr auto acc_dstr =
make_static_tile_distribution(ck_tile::detail::make_reduce_tile_distribution_encoding(
InDistributedTensor::get_tile_distribution()
.get_static_tile_distribution_encoding(),
ReduceDim{}));
return make_static_distributed_tensor<InDataType>(acc_dstr);
}
// return number of pixels each lane need to reduce
CK_TILE_HOST_DEVICE constexpr auto get_reduce_length_y() const
{
constexpr auto spans = InDistributedTensor::get_distributed_spans();
}
// Here ReducePacksPerXDim is not the same meaning as that in static_uford/sweep_tile_uspan
// this is number of packs along the X-dim. We need to compute the Unpacks along the Y dim
// internally
// For simplicity, we just support along the row dimension, ReducePacksPerXDim is always 2
// element , and the first element is always ignored For simplicity, will always try from
// right-to-left to find alone which Y dim to split
template <typename ReduceFunc,
typename ReduceSyncFunc,
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func,
const ReduceSyncFunc& reduce_sync_func,
ReducePacksPerXDim = {}) const
{
constexpr auto spans = InDistributedTensor::get_distributed_spans();
constexpr auto row_y_unpacks = [&]() {
constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
constexpr auto row_y_size =
reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});
static_assert(row_y_size % row_y_packs == 0);
constexpr auto row_y_slice_size = row_y_size / row_y_packs;
constexpr auto slice_info = slice_sequence(row_y_lengths, number<row_y_slice_size>{});
constexpr auto unpacks = slice_info[number<1>{}];
return unpacks;
}();
auto acc_tensor = MakeDstBlockTile();
// in-thread reduction
// FIXME: hard coded to be 2D to 1D reduction
sweep_tile_span(spans[number<0>{}], [&](auto dstr_idx_i0) {
constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);
auto acc = acc_tensor[acc_dstr_idx];
sweep_tile_uspan(
spans[number<1>{}],
[&](auto... dstr_idx_i1) {
acc = reduce_func(acc, t[make_tuple(dstr_idx_i0, dstr_idx_i1)]...);
},
row_y_unpacks);
acc_tensor(acc_dstr_idx) = acc;
});
// TODO: always use xor to do cross-lane reduce
block_tile_reduce_xor_sync(acc_tensor, reduce_sync_func);
return acc_tensor;
}
template <typename ReduceFunc>
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func) const
{
return operator()(reduce_func, reduce_func);
}
InDistributedTensor t;
InDataType reduce_init;
};
// deduction guide
template <typename T>
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&)->BlockReduce2D<T>;
} // namespace ck_tile

View File

@@ -0,0 +1,8 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -0,0 +1,81 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#define _BLOCK_SOFTMAX_USE_UNPACK2 0
namespace ck_tile {
/*
simple 2d softmax implementation, along row (dim=1)
requirement:
1). each row is within a warp
2). data type must be a dword
*/
template <typename Problem_, typename Policy_ = void>
struct BlockSoftmax2D
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using DataType = typename Problem::DataType;
template <typename DistributedTensor, index_t dim = 1>
CK_TILE_DEVICE void
operator()(const DistributedTensor& x, DistributedTensor& y, number<dim> = {})
{
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
#if _BLOCK_SOFTMAX_USE_UNPACK2
const auto f_max3 = [](auto e0, auto e1, auto e2) {
float rtn;
asm volatile("v_max3_f32 %0, %1, %2, %3" : "=v"(rtn) : "v"(e0), "v"(e1), "v"(e2));
return rtn;
};
const auto f_sum3 = [](auto e0, auto e1, auto e2) { return e0 + e1 + e2; };
#endif
// compute row max
auto reduce_row_max = BlockReduce2D{x, -numeric<DataType>::infinity()};
#if _BLOCK_SOFTMAX_USE_UNPACK2
auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
#else
auto row_max = reduce_row_max(f_max);
#endif
sweep_tile<DistributedTensor>([&](auto idx) {
constexpr auto row_id = make_tuple(idx[number<0>{}]);
y(idx) = exp(x[idx] - row_max[row_id]);
});
// compute row sum
auto reduce_row_sum = BlockReduce2D<decltype(y)>{y, DataType{0}};
#if _BLOCK_SOFTMAX_USE_UNPACK2
auto row_sum = reduce_row_sum(f_sum3, f_sum, sequence<1, 2>{});
#else
auto row_sum = reduce_row_sum(f_sum);
#endif
// reciprocal
auto r = make_static_distributed_tensor<DataType>(row_sum.get_tile_distribution());
sweep_tile(row_sum, [&](auto idx) { r(idx) = DataType{1} / row_sum(idx); });
// scale
sweep_tile<DistributedTensor>([&](auto idx) {
constexpr auto row_id = make_tuple(idx[number<0>{}]);
y(idx) = y(idx) * r(row_id);
});
}
template <typename DistributedTensor, index_t dim = 1>
CK_TILE_DEVICE decltype(auto) operator()(const DistributedTensor& x, number<dim> = {})
{
auto y = DistributedTensor{}; // distributed tensor
operator()(x, y, number<dim>{});
return y;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,16 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename DataType_>
struct BlockSoftmax2DProblem
{
using DataType = remove_cvref_t<DataType_>;
};
} // namespace ck_tile

View File

@@ -0,0 +1,8 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -0,0 +1,113 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
simple 2d topk implementation, along row (dim=1)
requirement:
1). each row is within a warp
*/
template <typename Problem_, typename Policy_ = void>
struct BlockTopkStream2D
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using DataType = typename Problem::DataType;
using IndexType = typename Problem::IndexType;
// TODO: if DataType is subdword, need pack into single dword to use argmax
struct ArgmaxPacket
{
DataType arg;
index_t value;
};
template <typename DistributedTensor, typename OutWindow, typename IdxWindow, index_t dim = 1>
CK_TILE_DEVICE void operator()(const DistributedTensor& x,
const OutWindow& out_window,
const IdxWindow& idx_window,
index_t k,
number<dim> = {})
{
OutWindow out_window_tmp = out_window;
IdxWindow idx_window_tmp = idx_window;
static_assert(
std::is_same_v<typename DistributedTensor::DataType, typename OutWindow::DataType> &&
std::is_same_v<typename DistributedTensor::DataType, DataType>);
static_assert(std::is_same_v<typename IdxWindow::DataType, IndexType>);
DistributedTensor x_tmp = x;
constexpr auto dst_dist = typename IdxWindow::TileDstr{};
// argmax for topk
const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
return e0.arg > e1.arg ? e0 : e1;
};
for(index_t i_k = 0; i_k < k; i_k++)
{
constexpr auto span_2d = DistributedTensor::get_distributed_spans();
auto packet = [&]() {
auto tmp = make_static_distributed_tensor<ArgmaxPacket>(x.get_tile_distribution());
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
tmp.get_tile_distribution(), make_tuple(idx0, idx1));
constexpr auto i_j_idx = make_tuple(idx0, idx1);
ArgmaxPacket t;
t.arg = x_tmp(i_j_idx); // !!! we reference x here
t.value = tile_idx.at(number<1>{});
tmp(i_j_idx) = t;
});
});
return tmp;
}();
auto argmax_init = ArgmaxPacket{-numeric<DataType>::infinity(), 0};
auto r = block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
block_tile_reduce_xor_sync(r, f_argmax);
auto o = make_static_distributed_tensor<DataType>(dst_dist);
auto i = make_static_distributed_tensor<IndexType>(dst_dist);
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
ArgmaxPacket tmp = r(i_j_idx);
o(i_j_idx) = tmp.arg;
i(i_j_idx) = tmp.value;
});
});
// update value
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
x.get_tile_distribution(), make_tuple(idx0, idx1));
auto col_id = tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
x_tmp(i_j_idx) = (col_id == r(i_j_idx).value) ? -numeric<DataType>::infinity()
: x_tmp(i_j_idx);
});
});
if(threadIdx.x % Problem::ColLanes == 0)
{
store_tile(out_window_tmp, o);
store_tile(idx_window_tmp, i);
}
move_tile_window(out_window_tmp, {number<0>{}, number<1>{}});
move_tile_window(idx_window_tmp, {number<0>{}, number<1>{}});
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,22 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
simple 2d topk implementation, along row (dim=1)
requirement:
1). each row is within a warp
*/
template <typename DataType_, typename IndexType_, index_t ColLanes_>
struct BlockTopkStream2DProblem
{
using DataType = remove_cvref_t<DataType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t ColLanes = ColLanes_;
};
} // namespace ck_tile

View File

@@ -0,0 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -0,0 +1,166 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
struct TopkSoftmaxHostArgs
{
const void* p_input;
void* p_output;
void* p_indices;
index_t num_rows;
index_t num_experts;
index_t topk;
index_t stride_input; // row stride for input, at least experts
index_t stride_output; // row stride for output/indices, at least tpok
};
template <typename Pipeline_>
struct TopkSoftmaxKernel
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
using InputType = typename Problem::InputType;
using WeightType = typename Problem::WeightType;
using IndexType = typename Problem::IndexType;
struct TopkSoftmaxKargs
{
const void* p_input;
void* p_output;
void* p_indices;
index_t num_rows;
index_t num_experts;
index_t topk;
index_t stride_input; // row stride for input, at least experts
index_t stride_output; // row stride for output/indices, at least tpok
};
using Kargs = TopkSoftmaxKargs;
using Hargs = TopkSoftmaxHostArgs;
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{
if constexpr(Problem::LaunchType > 0)
{
int num_cu = [&]() {
hipDeviceProp_t dev_prop;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
return dev_prop.multiProcessorCount;
}();
return dim3(num_cu * Problem::LaunchType);
}
else
{
const int num_warps = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
const int num_blocks =
(num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
return dim3(num_blocks);
}
}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_input = h.p_input;
k.p_output = h.p_output;
k.p_indices = h.p_indices;
k.num_rows = h.num_rows;
k.num_experts = h.num_experts;
k.topk = h.topk;
k.stride_input = h.stride_input;
k.stride_output = h.stride_output;
return k;
}
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);
if(block_row_id > kargs.num_rows)
return;
index_t block_os_inp = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_input);
index_t block_os_out = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_output);
index_t num_rows_rem = __builtin_amdgcn_readfirstlane(kargs.num_rows - block_row_id);
const auto input_window = [&]() {
const InputType* p_input =
reinterpret_cast<const InputType*>(kargs.p_input) + block_os_inp;
auto tmp = make_naive_tensor_view<address_space_enum::global>(
p_input,
make_tuple(num_rows_rem, kargs.num_experts),
make_tuple(kargs.stride_input, 1),
number<Problem::VectorSize>{},
number<1>{});
auto view = pad_tensor_view(
tmp,
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
sequence<0, 1>{}); // out-most dim no need pad(leverage oob)
return make_tile_window(
view,
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
{0, 0});
}();
auto output_window = [&]() {
WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) + block_os_out;
auto tmp = make_naive_tensor_view<address_space_enum::global>(
p_output,
make_tuple(num_rows_rem, kargs.topk),
make_tuple(kargs.stride_output, 1),
number<Problem::VectorSize>{},
number<1>{});
auto view =
pad_tensor_view(tmp,
make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
sequence<0, 0>{}); // 1. out-most dim no need pad(leverage oob)
// 2. we loop over topk 1-1, no need padding
return make_tile_window(
view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
}();
auto indices_window = [&]() {
IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) + block_os_out;
auto tmp = make_naive_tensor_view<address_space_enum::global>(
p_indices,
make_tuple(num_rows_rem, kargs.topk),
make_tuple(kargs.stride_output, 1),
number<Problem::VectorSize>{},
number<1>{});
auto view =
pad_tensor_view(tmp,
make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
sequence<0, 0>{}); // 1. out-most dim no need pad(leverage oob)
// 2. we loop over topk 1-1, no need padding
return make_tile_window(
view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
}();
Pipeline{}(input_window,
output_window,
indices_window,
kargs.num_rows,
kargs.num_experts,
kargs.topk,
block_row_id);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,123 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include <string>
#include <type_traits>
#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
#endif
namespace ck_tile {
template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
struct TopkSoftmaxWarpPerRowPipeline
{
// TODO: this kernel only support warp per row
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using WeightType = typename Problem::WeightType;
template <typename InputWindow, typename OutputWindow, typename IndexWindow>
CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
OutputWindow& out_window,
IndexWindow& idx_window,
index_t rows,
index_t experts,
index_t k,
index_t block_row_id)
{
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
auto inp_win = make_tile_window_linear_raw(
input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
#else
auto inp_win = make_tile_window_linear(
input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
#endif
auto out_win = make_tile_window_linear(out_window.get_bottom_tensor_view(),
out_window.get_window_lengths(),
out_window.get_window_origin(),
Policy::template MakeOutputDistribution<Problem>());
auto idx_win = make_tile_window_linear(idx_window.get_bottom_tensor_view(),
idx_window.get_window_lengths(),
idx_window.get_window_origin(),
Policy::template MakeOutputDistribution<Problem>());
auto softmax = Policy::template GetSoftmax<Problem>();
auto topk = Policy::template GetTopk<Problem>();
const index_t grid_rows_per_loop = gridDim.x * Problem::RowsPerBlock;
while(1)
{
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
__builtin_amdgcn_sched_barrier(0);
auto x =
load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});
buffer_load_fence(number<0>{});
__builtin_amdgcn_sched_barrier(0);
#else
auto x = load_tile(inp_win);
#endif
// cast and pad input data
auto w = [&]() {
#if 0
auto w_ = cast_tile<WeightType>(x);
constexpr auto span_2d = decltype(w_)::get_distributed_spans();
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
w_.get_tile_distribution(), i_j_idx);
const auto current_expert = x_indices.at(number<1>{});
// set to -INF if OOB so that later softmax can work properly
w_(i_j_idx) = current_expert >= experts ? -numeric<WeightType>::infinity()
: w_(i_j_idx);
});
});
return w_;
#else
auto w_ = make_static_distributed_tensor<WeightType>(x.get_tile_distribution());
auto w_f = [&](auto idx) {
w_(idx) = type_convert<WeightType>(x(idx));
const auto x_indices =
get_x_indices_from_distributed_indices(w_.get_tile_distribution(), idx);
const auto current_expert = x_indices.at(number<1>{});
w_(idx) =
current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
};
tile_sweeper ts{w_, w_f};
ts();
return w_;
#endif
}();
// softmax
auto y = softmax(w);
topk(y, out_win, idx_win, k);
// check exit
if constexpr(Problem::LaunchType == 0)
{
break;
}
else
{
block_row_id += grid_rows_per_loop;
if(block_row_id >= rows)
break;
}
move_tile_window(inp_win, {grid_rows_per_loop, number<0>{}});
move_tile_window(out_win, {grid_rows_per_loop, number<0>{}});
move_tile_window(idx_win, {grid_rows_per_loop, number<0>{}});
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
namespace ck_tile {
struct TopkSoftmaxWarpPerRowPolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
{
// TODO: Y dim must have one dim that is not reduced
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<Problem::IssuesPerCol,
Problem::WarpsPerBlock,
Problem::RowsPerWarpPerColIssue>,
sequence<Problem::IssuesPerRow, Problem::LanesPerRow, Problem::VectorSize>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<Problem::LanesPerRow>, // repeat this one
tuple<sequence<Problem::IssuesPerCol,
Problem::WarpsPerBlock,
Problem::RowsPerWarpPerColIssue>,
sequence<1>>, // each row write out single element
tuple<sequence<1>, sequence<1, 0>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSoftmax()
{
using softmax_problem = BlockSoftmax2DProblem<typename Problem::WeightType>;
return BlockSoftmax2D<softmax_problem>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTopk()
{
using topk_problem = BlockTopkStream2DProblem<typename Problem::WeightType,
typename Problem::IndexType,
Problem::LanesPerRow>;
// Note: replicate is LanesPerRow
return BlockTopkStream2D<topk_problem>{};
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,46 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename InputType_,
typename WeightType_,
typename IndexType_,
index_t Experts_,
index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
index_t BytesPerIssue_ = sizeof(InputType_),
index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy
index_t BlockSize_ = 256>
struct TopkSoftmaxWarpPerRowProblem
{
// TODO: this kernel only support warp per row
using InputType = remove_cvref_t<InputType_>;
using WeightType = remove_cvref_t<WeightType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t LaunchType = LaunchType_;
static constexpr index_t Experts = Experts_;
static constexpr index_t BytesPerIssue = BytesPerIssue_;
static constexpr index_t IssuesPerCol = IssuesPerCol_;
static constexpr index_t BlockSize = BlockSize_;
static constexpr index_t WarpSize = get_warp_size();
static_assert(BytesPerIssue % sizeof(InputType) == 0);
static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);
static_assert(Experts % VectorSize == 0);
static constexpr index_t LanesPerRow = min(Experts / VectorSize, WarpSize);
static_assert(WarpSize % LanesPerRow == 0);
static constexpr index_t RowsPerWarpPerColIssue = WarpSize / LanesPerRow;
static constexpr index_t RowsPerWarp = IssuesPerCol * RowsPerWarpPerColIssue;
static constexpr index_t IssuesPerRow = Experts / (LanesPerRow * VectorSize);
static constexpr index_t WarpsPerBlock = BlockSize / WarpSize;
static constexpr index_t RowsPerBlock = RowsPerWarp * WarpsPerBlock;
};
} // namespace ck_tile