mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
topk_softmax (#1592)
* topk_softmax
* remove some file
* fix atomix linear_offset
* address various comment, and change sfc get_index api to static(tuple)
[ROCm/composable_kernel commit: b098b71b05]
This commit is contained in:
8
example/ck_tile/09_topk_softmax/CMakeLists.txt
Normal file
8
example/ck_tile/09_topk_softmax/CMakeLists.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
# Example driver for the ck_tile topk-softmax kernel. EXCLUDE_FROM_ALL keeps it out of the
# default build; build it explicitly with `make tile_example_topk_softmax`.
add_executable(tile_example_topk_softmax EXCLUDE_FROM_ALL
    topk_softmax.cpp
    topk_softmax_api.cpp)
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)

set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS)
# NOTE: undefined-func-template is turned off so the sources compile without explicitly
# declaring function template specializations
list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS
    -Wno-undefined-func-template
    -Wno-float-equal)
# list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(tile_example_topk_softmax PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})
|
||||
28
example/ck_tile/09_topk_softmax/README.md
Normal file
28
example/ck_tile/09_topk_softmax/README.md
Normal file
@@ -0,0 +1,28 @@
|
||||
# topk-softmax
|
||||
|
||||
This folder contains an example of the topk-softmax kernel implemented with ck_tile tile-programming. This kernel is often used in MoE (mixture-of-experts) models, before launching the fused-moe-gemm block. The input is a `token*expert` 2d matrix. The op performs a softmax over each row (the `expert` dimension), then finds the `topk` values of each row. The output is a `token*topk` weight (usually fp32) 2d tensor and a `token*topk` index (int32) 2d tensor.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_topk_softmax -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_topk_softmax`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-v whether to do CPU validation or not (default:1)
|
||||
-pr_i input data type. fp16/bf16 (default:fp16)
|
||||
-pr_w output weight data type(currently only fp32 supported now) (default:fp32)
|
||||
-t number of input tokens (default:32)
|
||||
-e number of experts (default:8)
|
||||
-k topk (default:2)
|
||||
-st_i row stride of input, -1 means same as experts (default:-1)
|
||||
-st_o row stride of output/indices, -1 means same as topk (default:-1)
|
||||
-seed seed to be used, -1 means random every time (default:-1)
|
||||
-kname when set to 1 it will print kernel name (default:0)
|
||||
|
||||
```
|
||||
22
example/ck_tile/09_topk_softmax/script/smoke_test.sh
Normal file
22
example/ck_tile/09_topk_softmax/script/smoke_test.sh
Normal file
@@ -0,0 +1,22 @@
|
||||
#!/bin/sh
# Smoke test for the topk-softmax example: runs a spread of token/expert/topk/stride
# combinations for each input precision listed in the loop below.
# Every case is run even if an earlier one fails; the script exits non-zero if any failed.

EXE=./build/bin/tile_example_topk_softmax
status=0

# Run one case with the given arguments; remember a failure but keep going.
run_case() {
    $EXE "$@" || status=1
}

for pr_i in "fp16" "bf16" ; do
run_case -pr_i=$pr_i -t=80 -e=17
run_case -pr_i=$pr_i -t=111 -e=117
run_case -pr_i=$pr_i -t=1000 -e=55
run_case -pr_i=$pr_i -t=99 -e=180
run_case -pr_i=$pr_i -t=175 -e=64 -k=8
run_case -pr_i=$pr_i -t=65 -e=8 -k=2
run_case -pr_i=$pr_i -t=1 -e=25
run_case -pr_i=$pr_i -t=31 -e=19 -k=15
run_case -pr_i=$pr_i -t=81 -e=37 -k=7
run_case -pr_i=$pr_i -t=199 -e=128 -k=13
run_case -pr_i=$pr_i -t=23 -e=1 -k=1
run_case -pr_i=$pr_i -t=127 -e=99 -k=19 -st_i=233 -st_o=31
run_case -pr_i=$pr_i -t=71 -e=11 -k=11 -st_i=30 -st_o=12
run_case -pr_i=$pr_i -t=1 -e=1 -k=1
run_case -pr_i=$pr_i -t=99 -e=2 -k=1 -st_i=11 -st_o=5
run_case -pr_i=$pr_i -t=333 -e=99 -k=13 -st_i=191 -st_o=17
done

exit $status
|
||||
299
example/ck_tile/09_topk_softmax/topk_softmax.cpp
Normal file
299
example/ck_tile/09_topk_softmax/topk_softmax.cpp
Normal file
@@ -0,0 +1,299 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <time.h>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "topk_softmax_api.hpp"
|
||||
|
||||
#if 0
|
||||
template <typename T>
|
||||
void dump_host_tensor_2d(const ck_tile::HostTensor<T>& x)
|
||||
{
|
||||
auto len = x.get_lengths();
|
||||
assert(len.size() == 2);
|
||||
std::cout << "[";
|
||||
for(size_t i = 0; i < len[0]; i++)
|
||||
{
|
||||
std::cout << i << ": [";
|
||||
for(size_t j = 0; j < len[1]; j++)
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
|
||||
{
|
||||
auto v = ck_tile::type_convert<float>(x(i, j));
|
||||
|
||||
std::cout << v;
|
||||
if(j != len[1] - 1)
|
||||
std::cout << ",";
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << x(i, j) << " ";
|
||||
}
|
||||
}
|
||||
std::cout << "]";
|
||||
if(i != len[0] - 1)
|
||||
std::cout << ",";
|
||||
else
|
||||
std::cout << "]";
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << "--------------------" << std::endl;
|
||||
}
|
||||
#endif
|
||||
|
||||
// CPU reference
|
||||
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t dim = -1,
|
||||
bool largest = true,
|
||||
bool sorted = true)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
auto y = reference_softmax<InputType, WeightType, WeightType>(x, dim);
|
||||
|
||||
auto [y_values, y_indices] = reference_topk(y, k, dim, largest, sorted);
|
||||
|
||||
return ck_tile::make_tuple(y_values, y_indices);
|
||||
}
|
||||
|
||||
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
|
||||
ck_tile::HostTensor<WeightType>& y_values,
|
||||
ck_tile::HostTensor<IndexType>& y_indices,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t dim = -1,
|
||||
bool largest = true,
|
||||
bool sorted = true)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
auto y = reference_softmax<InputType, WeightType, WeightType>(x, dim);
|
||||
reference_topk(y, y_values, y_indices, k, dim, largest, sorted);
|
||||
}
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
|
||||
{
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
unsigned max_rounding_point_distance = 0;
|
||||
double atol = 2e-3;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
unsigned max_rounding_point_distance = 1;
|
||||
double atol = 0.0625;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("pr_i", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
|
||||
.insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
|
||||
.insert("t", "32", "number of input tokens")
|
||||
.insert("e", "8", "number of experts")
|
||||
.insert("k", "2", "topk")
|
||||
.insert("st_i", "-1", "row stride of input, -1 means same as experts")
|
||||
.insert("st_o", "-1", "row stride of output/indices, -1 means same as topk")
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "when set to 1 it will print kernel name")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
bool test_topk_softmax(ck_tile::ArgParser args)
|
||||
{
|
||||
int validate = args.get_int("v");
|
||||
std::string input_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
int tokens = args.get_int("t");
|
||||
int experts = args.get_int("e");
|
||||
int topk = args.get_int("k");
|
||||
int seed = args.get_int("seed");
|
||||
int stride_input = args.get_int("st_i");
|
||||
int stride_output = args.get_int("st_o");
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
|
||||
if(stride_input < 0)
|
||||
{
|
||||
stride_input = experts;
|
||||
}
|
||||
if(stride_output < 0)
|
||||
{
|
||||
stride_output = topk;
|
||||
}
|
||||
assert(stride_input >= experts);
|
||||
assert(stride_output >= topk);
|
||||
|
||||
if(seed < 0)
|
||||
{
|
||||
seed = std::time(nullptr);
|
||||
}
|
||||
|
||||
if(topk > experts)
|
||||
{
|
||||
printf("topk:%d value should be smaller than, or equal to number of experts:%d\n",
|
||||
topk,
|
||||
experts);
|
||||
return false;
|
||||
}
|
||||
|
||||
// tokens already considered batch size
|
||||
ck_tile::HostTensor<InputType> x_host({tokens, experts}, {stride_input, 1});
|
||||
ck_tile::HostTensor<WeightType> value_host({tokens, topk}, {stride_output, 1});
|
||||
ck_tile::HostTensor<IndexType> index_host({tokens, topk}, {stride_output, 1});
|
||||
|
||||
{
|
||||
// random require per-row unique
|
||||
auto rand_gen = ck_tile::FillUniformDistribution_Unique<InputType>{
|
||||
-5.f, 5.f, static_cast<uint32_t>(seed)};
|
||||
|
||||
for(int i_t = 0; i_t < tokens; i_t++)
|
||||
{
|
||||
ck_tile::HostTensor<InputType> x_row({experts});
|
||||
rand_gen(x_row);
|
||||
std::copy(x_row.begin(), x_row.end(), x_host.begin() + i_t * stride_input);
|
||||
rand_gen.clear();
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem value_dev(value_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem index_dev(index_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_dev.ToDevice(x_host.data());
|
||||
|
||||
topk_softmax_trait trait{input_prec, weight_prec, experts};
|
||||
|
||||
topk_softmax_kargs karg{x_dev.GetDeviceBuffer(),
|
||||
value_dev.GetDeviceBuffer(),
|
||||
index_dev.GetDeviceBuffer(),
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride_input,
|
||||
stride_output};
|
||||
|
||||
ck_tile::stream_config sc{nullptr,
|
||||
true,
|
||||
/* log_level = */ (kname ? 1 : 0),
|
||||
warmup,
|
||||
repeat};
|
||||
auto ms = topk_softmax(trait, karg, sc);
|
||||
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ",
|
||||
input_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride_input,
|
||||
stride_output,
|
||||
ms);
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
fflush(stdout);
|
||||
if(ms < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
value_dev.FromDevice(value_host.data());
|
||||
index_dev.FromDevice(index_host.data());
|
||||
|
||||
bool rtn = true;
|
||||
if(validate)
|
||||
{
|
||||
ck_tile::HostTensor<WeightType> value_ref({tokens, topk}, {stride_output, 1});
|
||||
ck_tile::HostTensor<IndexType> index_ref({tokens, topk}, {stride_output, 1});
|
||||
|
||||
reference_topk_softmax<InputType, WeightType, IndexType>(
|
||||
x_host, value_ref, index_ref, topk);
|
||||
|
||||
auto [rtol, atol] = get_elimit<InputType>("");
|
||||
for(int i_t = 0; i_t < tokens; i_t++)
|
||||
{
|
||||
auto s_begin = std::vector<size_t>{static_cast<size_t>(i_t), static_cast<size_t>(0)};
|
||||
auto s_end =
|
||||
std::vector<size_t>{static_cast<size_t>(i_t + 1), static_cast<size_t>(topk)};
|
||||
auto s_value_host = value_host.slice(s_begin, s_end);
|
||||
auto s_value_ref = value_ref.slice(s_begin, s_end);
|
||||
rtn &= ck_tile::check_err(s_value_host,
|
||||
s_value_ref,
|
||||
std::string("[") + std::to_string(i_t) +
|
||||
std::string("] Value Error:"),
|
||||
rtol,
|
||||
atol);
|
||||
auto s_index_host = index_host.slice(s_begin, s_end);
|
||||
auto s_index_ref = index_ref.slice(s_begin, s_end);
|
||||
rtn &= ck_tile::check_err(s_index_host,
|
||||
s_index_ref,
|
||||
std::string("[") + std::to_string(i_t) +
|
||||
std::string("] Index Error:"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
|
||||
printf("valid:%s\n", rtn ? "y" : "n");
|
||||
fflush(stdout);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
std::string input_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
|
||||
bool r = true;
|
||||
if(input_prec.compare("fp16") == 0 && weight_prec.compare("fp32") == 0)
|
||||
{
|
||||
r &= test_topk_softmax<ck_tile::fp16_t, float, ck_tile::index_t>(args);
|
||||
}
|
||||
else if(input_prec.compare("bf16") == 0 && weight_prec.compare("fp32") == 0)
|
||||
{
|
||||
r &= test_topk_softmax<ck_tile::bf16_t, float, ck_tile::index_t>(args);
|
||||
}
|
||||
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
96
example/ck_tile/09_topk_softmax/topk_softmax_api.cpp
Normal file
96
example/ck_tile/09_topk_softmax/topk_softmax_api.cpp
Normal file
@@ -0,0 +1,96 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "topk_softmax_api.hpp"
|
||||
|
||||
// Build the warp-per-row topk-softmax kernel type for a compile-time expert count
// `experts_`, launch it through ck_tile::launch_kernel, and return that call's result from
// the enclosing function.
// The expansion site must provide the types ts_input_type / ts_weight_type / ts_index_type
// and the variables `a` (host arguments) and `s` (stream config).
#define TOPK_SOFTMAX_DISPATCH(experts_) \
    constexpr ck_tile::index_t ts_experts = experts_; \
    using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem<ts_input_type, \
                                                             ts_weight_type, \
                                                             ts_index_type, \
                                                             ts_experts>; \
    using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
    using kernel      = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
    \
    auto kargs            = kernel::MakeKargs(a); \
    const dim3 grids      = kernel::GridSize(a); \
    constexpr dim3 blocks = kernel::BlockSize(); \
    \
    return ck_tile::launch_kernel( \
        s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
// Host-side dispatcher: picks a kernel instance from the runtime trait `t` and launches it
// with arguments `a` on stream config `s`.
// Only fp32 weights with fp16 or bf16 input are instantiated; the expert count is rounded
// up to the nearest instantiated size (8/16/32/64/128/192).
// Returns the value produced by TOPK_SOFTMAX_DISPATCH, or -1 when no instance matches.
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
{
    // every instantiated combination uses fp32 weights
    if(t.weight_type != "fp32")
    {
        return -1;
    }

    using ts_weight_type = float;
    using ts_index_type  = ck_tile::index_t;

    if(t.input_type == "fp16")
    {
        using ts_input_type = ck_tile::fp16_t;
#if 1
        if(t.experts <= 8)
        {
            TOPK_SOFTMAX_DISPATCH(8)
        }
        else if(t.experts <= 16)
        {
            TOPK_SOFTMAX_DISPATCH(16)
        }
        else if(t.experts <= 32)
        {
            TOPK_SOFTMAX_DISPATCH(32)
        }
        else if(t.experts <= 64)
        {
            TOPK_SOFTMAX_DISPATCH(64)
        }
        else if(t.experts <= 128)
        {
            TOPK_SOFTMAX_DISPATCH(128)
        }
        else if(t.experts <= 192)
        {
            TOPK_SOFTMAX_DISPATCH(192)
        }
#else
        if(t.experts <= 128)
        {
            TOPK_SOFTMAX_DISPATCH(128)
        }
#endif
    }
    else if(t.input_type == "bf16")
    {
#if 1
        using ts_input_type = ck_tile::bf16_t;
        if(t.experts <= 8)
        {
            TOPK_SOFTMAX_DISPATCH(8)
        }
        else if(t.experts <= 16)
        {
            TOPK_SOFTMAX_DISPATCH(16)
        }
        else if(t.experts <= 32)
        {
            TOPK_SOFTMAX_DISPATCH(32)
        }
        else if(t.experts <= 64)
        {
            TOPK_SOFTMAX_DISPATCH(64)
        }
        else if(t.experts <= 128)
        {
            TOPK_SOFTMAX_DISPATCH(128)
        }
        else if(t.experts <= 192)
        {
            TOPK_SOFTMAX_DISPATCH(192)
        }
#endif
    }

    // unknown input type, or more experts than the largest instance
    return -1;
}
|
||||
21
example/ck_tile/09_topk_softmax/topk_softmax_api.hpp
Normal file
21
example/ck_tile/09_topk_softmax/topk_softmax_api.hpp
Normal file
@@ -0,0 +1,21 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/topk_softmax.hpp"
|
||||
#include <string>
|
||||
|
||||
// Runtime description of the kernel instance requested from topk_softmax().
struct topk_softmax_trait
{
    // input element type by name
    std::string input_type;
    // weight (output value) element type by name; currently always float
    std::string weight_type;
    // number of experts
    int experts;
};
|
||||
|
||||
// Host-side kernel arguments for topk_softmax(); adds nothing on top of
// ck_tile::TopkSoftmaxHostArgs.
struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs
{
};
|
||||
|
||||
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s);
|
||||
@@ -7,3 +7,5 @@ add_subdirectory(02_layernorm2d)
|
||||
add_subdirectory(03_gemm)
|
||||
add_subdirectory(04_img2col)
|
||||
add_subdirectory(05_reduce)
|
||||
add_subdirectory(09_topk_softmax)
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/tensor/update_tile.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -81,8 +81,10 @@ struct space_filling_curve
|
||||
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
|
||||
}
|
||||
|
||||
// Do not use this function directly!
|
||||
// TODO: can refactor into generic lambda in the future
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr Index get_index(number<AccessIdx1d>)
|
||||
static CK_TILE_HOST_DEVICE constexpr Index _get_index(number<AccessIdx1d>)
|
||||
{
|
||||
#if 0
|
||||
/*
|
||||
@@ -153,11 +155,11 @@ struct space_filling_curve
|
||||
return idx_md;
|
||||
}
|
||||
|
||||
// FIXME: rename this function
|
||||
// FIXME: return tuple of number<>, which is compile time only variable
|
||||
template <index_t AccessIdx1d>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
|
||||
static CK_TILE_HOST_DEVICE constexpr auto get_index(number<AccessIdx1d>)
|
||||
{
|
||||
constexpr auto idx = get_index(number<AccessIdx1d>{});
|
||||
constexpr auto idx = _get_index(number<AccessIdx1d>{});
|
||||
|
||||
return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
|
||||
}
|
||||
|
||||
@@ -621,6 +621,99 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
// below type indicate the data type used for buffer load inline asm
|
||||
// clang-format off
|
||||
template<index_t N, typename T> struct smem_load_trait;
|
||||
|
||||
template<typename T> struct smem_load_trait<16, T> { using payload_t = fp32x4_t; };
|
||||
template<typename T> struct smem_load_trait<8 , T> { using payload_t = fp32x2_t; };
|
||||
template<typename T> struct smem_load_trait<4 , T> { using payload_t = float; };
|
||||
template<typename T> struct smem_load_trait<2 , T> { using payload_t = float; };
|
||||
template<typename T> struct smem_load_trait<1 , T> { using payload_t = float; };
|
||||
|
||||
// clang-format on
|
||||
} // namespace impl
|
||||
|
||||
// NOTE: smem load/store no need pre_nop to make sure dependency by sw, happy :)
|
||||
template <index_t>
|
||||
struct smem_load;
|
||||
|
||||
template <>
|
||||
struct smem_load<16>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
|
||||
{
|
||||
static_assert(sizeof(T) == 16);
|
||||
using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t;
|
||||
asm volatile("ds_read_b128 %0, %1 offset:%2"
|
||||
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
|
||||
: "v"(v_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct smem_load<8>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
|
||||
{
|
||||
static_assert(sizeof(T) == 8);
|
||||
using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t;
|
||||
asm volatile("ds_read_b64 %0, %1 offset:%2"
|
||||
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
|
||||
: "v"(v_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct smem_load<4>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t;
|
||||
asm volatile("ds_read_b32 %0, %1 offset:%2"
|
||||
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
|
||||
: "v"(v_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct smem_load<2>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
|
||||
{
|
||||
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
|
||||
using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
|
||||
asm volatile("ds_read_u16 %0, %1 offset:%2"
|
||||
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
|
||||
: "v"(v_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct smem_load<1>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
|
||||
asm volatile("ds_read_u8 %0, %1 offset:%2"
|
||||
: "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
|
||||
: "v"(v_offset), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
namespace impl{
|
||||
|
||||
@@ -976,6 +1069,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
|
||||
int soffset, // dst_wave_addr_offset
|
||||
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
|
||||
|
||||
// Direct loads from global to LDS.
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
index_t size,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t offset,
|
||||
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
|
||||
|
||||
template <bool pre_nop = false>
|
||||
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
|
||||
int32x4_t rsrc,
|
||||
@@ -1313,6 +1416,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
index_t src_linear_addr_offset,
|
||||
index_t flag = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
@@ -1327,7 +1431,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
flag,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -1337,7 +1441,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
flag,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -1365,6 +1469,43 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
index_t src_immediate_addr_offset = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
|
||||
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
index_t v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
smem,
|
||||
sizeof(uint32_t),
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
src_immediate_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
|
||||
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
|
||||
@@ -1685,6 +1826,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset,
|
||||
index_t dst_linear_addr_offset,
|
||||
index_t is_valid_element = 1)
|
||||
{
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
@@ -1698,7 +1840,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0,
|
||||
dst_linear_addr_offset,
|
||||
is_valid_element);
|
||||
}
|
||||
else
|
||||
@@ -1707,7 +1849,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
dst_linear_addr_offset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2014,6 +2156,7 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
|
||||
const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
index_t src_element_space_size,
|
||||
index_t is_valid_element = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
@@ -2022,12 +2165,14 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
|
||||
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
|
||||
dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -2041,16 +2186,19 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
|
||||
const int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
index_t is_valid_element = 0,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
|
||||
dst,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
is_valid_element,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -2066,6 +2214,7 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
index_t src_element_space_size,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
@@ -2073,9 +2222,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_async_buffer_load_impl<T, N, coherence>(
|
||||
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
|
||||
amd_async_buffer_load_impl<T, N, coherence>(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// This version support buffer resource as input arg
|
||||
@@ -2086,12 +2240,42 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
const int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_async_buffer_load_impl<T, N, coherence>(
|
||||
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
|
||||
amd_async_buffer_load_impl<T, N, coherence>(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// This version support buffer resource as input arg
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = false>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
|
||||
const int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_element_offset,
|
||||
index_t src_linear_element_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_async_buffer_load<T, N, coherence>(smem,
|
||||
src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
0,
|
||||
src_linear_addr_offset,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// buffer_store requires:
|
||||
@@ -2146,6 +2330,7 @@ template <typename T,
|
||||
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
const index_t dst_linear_element_offset,
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size)
|
||||
{
|
||||
@@ -2153,11 +2338,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
|
||||
|
||||
amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
0,
|
||||
dst_linear_addr_offset,
|
||||
dst_thread_element_valid);
|
||||
}
|
||||
|
||||
@@ -2221,16 +2408,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
|
||||
#endif
|
||||
}
|
||||
|
||||
// Direct loads from global to LDS.
|
||||
CK_TILE_DEVICE_EXTERN void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
index_t size,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t offset,
|
||||
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
|
||||
|
||||
template <typename T, index_t NumElemsPerThread>
|
||||
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
const index_t global_offset,
|
||||
|
||||
@@ -41,6 +41,19 @@
|
||||
#define CK_TILE_HOST_DEVICE_EXTERN
|
||||
#endif
|
||||
|
||||
// implementing the "memory address space" attribute
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
|
||||
// NOTE(review): the guard below tests `__HIPCC_` (a single trailing underscore),
// whereas HIP compilers are documented to define `__HIPCC__`. If `__HIPCC__` was
// intended, the first branch is never taken and all four macros expand to
// nothing - TODO confirm before changing, since enabling the attributes alters
// the pointer types of every declaration that uses these macros.
#ifdef __HIPCC_
|
||||
#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
|
||||
#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
|
||||
#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
|
||||
#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
|
||||
#else
|
||||
#define CK_TILE_GENERIC_ADDR
|
||||
#define CK_TILE_GLOBAL_ADDR
|
||||
#define CK_TILE_LDS_ADDR
|
||||
#define CK_TILE_BUF_RES_ADDR
|
||||
#endif
|
||||
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
|
||||
#endif
|
||||
@@ -205,3 +218,8 @@
|
||||
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
|
||||
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
|
||||
#endif
|
||||
|
||||
// workaround: the compiler does not emit a reciprocal instruction from __frcp_rn()
|
||||
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
|
||||
#define CK_TILE_WORKAROUND_SWDEV_383542 1
|
||||
#endif
|
||||
|
||||
@@ -623,7 +623,7 @@ template <typename... Ys,
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
|
||||
static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
|
||||
return y;
|
||||
@@ -635,7 +635,7 @@ template <typename... Ys,
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
|
||||
static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
|
||||
return y;
|
||||
@@ -647,7 +647,7 @@ template <typename... Xs,
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
@@ -655,13 +655,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] + y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
@@ -669,13 +677,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] - y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
@@ -706,6 +722,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
|
||||
return a * x;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] * y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
|
||||
@@ -487,55 +487,12 @@ struct log2e<float>
|
||||
template <typename T = double>
|
||||
constexpr T log2e_v = log2e<T>::value;
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
float abs(const float& x)
|
||||
{
|
||||
union
|
||||
{
|
||||
float f32;
|
||||
uint32_t u32;
|
||||
} y;
|
||||
y.f32 = x;
|
||||
y.u32 = y.u32 & 0x7fffffff;
|
||||
return y.f32;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool isnan(const float& x)
|
||||
{
|
||||
uint32_t xx = bit_cast<uint32_t>(x);
|
||||
return (xx & 0x7fffffff) > 0x7F800000;
|
||||
}
|
||||
|
||||
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
|
||||
|
||||
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float exp(float x) { return __ocml_exp_f32(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float exp(float x) { return std::expf(x); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float exp2(float x) { return exp2f(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float exp2(float x) { return std::exp2f(x); };
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float log(float x) { return __logf(x); };
|
||||
|
||||
CK_TILE_HOST
|
||||
float log(float x) { return std::logf(x); };
|
||||
|
||||
CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc)
|
||||
{
|
||||
return __builtin_amdgcn_sad_u16(x, y, acc);
|
||||
@@ -554,4 +511,933 @@ CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
|
||||
return (x > y ? (x - y) : (y - x)) + acc;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace ck_tile
|
||||
// the functions below need these data types to be defined first
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#ifndef __HIP_DEVICE_COMPILE__
|
||||
#include <cmath>
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
#if CK_TILE_WORKAROUND_SWDEV_383542
|
||||
extern "C" CK_TILE_DEVICE float __ocml_native_recip_f32(float);
|
||||
#endif
|
||||
|
||||
// math functions for the host, some are implemented by calling C++ std functions
|
||||
|
||||
// Host absolute-value overloads.

// float / double: defer to std::abs.
CK_TILE_HOST float abs(float x)
{
    return std::abs(x);
}

CK_TILE_HOST double abs(double x)
{
    return std::abs(x);
}

// int8_t: branch-free form. `mask` is x shifted right by 7 bits (0 for
// non-negative x, all ones for negative x when the shift is arithmetic), so
// (x ^ mask) - mask negates x only when it is negative.
CK_TILE_HOST int8_t abs(int8_t x)
{
    const int8_t mask = x >> (8 - 1);
    return (x ^ mask) - mask;
}

// int32_t: same scheme with a 31-bit shift.
CK_TILE_HOST int32_t abs(int32_t x)
{
    const int32_t mask = x >> (32 - 1);
    return (x ^ mask) - mask;
}

// fp16_t: clear the top bit of the 16-bit pattern and reinterpret it as fp16_t.
CK_TILE_HOST fp16_t abs(fp16_t x)
{
    const uint16_t magnitude_bits = bit_cast<uint16_t>(x) & 0x7fff;
    return bit_cast<fp16_t>(magnitude_bits);
}

#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// int4_t: same branch-free scheme with a 3-bit shift.
CK_TILE_HOST int4_t abs(int4_t x)
{
    const int4_t mask = x >> (4 - 1);
    return (x ^ mask) - mask;
}
#endif
|
||||
|
||||
// Host NaN tests.

// float / double: defer to std::isnan.
CK_TILE_HOST bool isnan(float x)
{
    return std::isnan(x);
}

CK_TILE_HOST bool isnan(double x)
{
    return std::isnan(x);
}

// Integer overloads always report false; the argument is intentionally unused.
CK_TILE_HOST bool isnan(int8_t x)
{
    static_cast<void>(x);
    return false;
}

CK_TILE_HOST bool isnan(int32_t x)
{
    static_cast<void>(x);
    return false;
}

// fp16_t: true when the 16-bit pattern with the top bit masked off is greater
// than 0x7C00.
CK_TILE_HOST bool isnan(fp16_t x)
{
    const uint16_t magnitude_bits = bit_cast<uint16_t>(x) & 0x7FFF;
    return magnitude_bits > 0x7C00;
}

#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST bool isnan(int4_t x)
{
    static_cast<void>(x);
    return false;
}
#endif
|
||||
|
||||
// Host square roots.

// fp16_t: widen to float, take std::sqrt, narrow the result back to fp16_t.
CK_TILE_HOST fp16_t sqrt(fp16_t x)
{
    const float widened = static_cast<float>(x);
    return static_cast<fp16_t>(std::sqrt(widened));
}

// float / double: defer to std::sqrt.
CK_TILE_HOST float sqrt(float x)
{
    return std::sqrt(x);
}

CK_TILE_HOST double sqrt(double x)
{
    return std::sqrt(x);
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T tanh(T x)
|
||||
{
|
||||
return type_convert<T>(std::tanhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float tanh<float>(float x)
|
||||
{
|
||||
return std::tanhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double tanh<double>(double x)
|
||||
{
|
||||
return std::tanh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T acos(T x)
|
||||
{
|
||||
return type_convert<T>(std::acosf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float acos<float>(float x)
|
||||
{
|
||||
return std::acosf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double acos<double>(double x)
|
||||
{
|
||||
return std::acos(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T neg(T x)
|
||||
{
|
||||
return type_convert<T>(-(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float neg<float>(float x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double neg<double>(double x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST int32_t neg<int32_t>(int32_t x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST int8_t neg<int8_t>(int8_t x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T atan(T x)
|
||||
{
|
||||
return type_convert<T>(std::atanf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float atan<float>(float x)
|
||||
{
|
||||
return std::atanf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double atan<double>(double x)
|
||||
{
|
||||
return std::atan(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T sin(T x)
|
||||
{
|
||||
return type_convert<T>(std::sinf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float sin<float>(float x)
|
||||
{
|
||||
return std::sinf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double sin<double>(double x)
|
||||
{
|
||||
return std::sin(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T asin(T x)
|
||||
{
|
||||
return type_convert<T>(std::asinf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float asin<float>(float x)
|
||||
{
|
||||
return std::asinf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double asin<double>(double x)
|
||||
{
|
||||
return std::asin(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T asinh(T x)
|
||||
{
|
||||
return type_convert<T>(std::asinhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float asinh<float>(float x)
|
||||
{
|
||||
return std::asinhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double asinh<double>(double x)
|
||||
{
|
||||
return std::asinh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T cos(T x)
|
||||
{
|
||||
return type_convert<T>(std::cosf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float cos<float>(float x)
|
||||
{
|
||||
return std::cosf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double cos<double>(double x)
|
||||
{
|
||||
return std::cos(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T acosh(T x)
|
||||
{
|
||||
return type_convert<T>(std::acoshf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float acosh<float>(float x)
|
||||
{
|
||||
return std::acoshf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double acosh<double>(double x)
|
||||
{
|
||||
return std::acosh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T tan(T x)
|
||||
{
|
||||
return type_convert<T>(std::tanf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float tan<float>(float x)
|
||||
{
|
||||
return std::tanf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double tan<double>(double x)
|
||||
{
|
||||
return std::tan(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T atanh(T x)
|
||||
{
|
||||
return type_convert<T>(std::atanhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float atanh<float>(float x)
|
||||
{
|
||||
return std::atanhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double atanh<double>(double x)
|
||||
{
|
||||
return std::atanh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T sinh(T x)
|
||||
{
|
||||
return type_convert<T>(std::sinhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float sinh<float>(float x)
|
||||
{
|
||||
return std::sinhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double sinh<double>(double x)
|
||||
{
|
||||
return std::sinh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T ceil(T x)
|
||||
{
|
||||
return type_convert<T>(std::ceilf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float ceil<float>(float x)
|
||||
{
|
||||
return std::ceilf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double ceil<double>(double x)
|
||||
{
|
||||
return std::ceil(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T cosh(T x)
|
||||
{
|
||||
return type_convert<T>(std::coshf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float cosh<float>(float x)
|
||||
{
|
||||
return std::coshf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double cosh<double>(double x)
|
||||
{
|
||||
return std::cosh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T floor(T x)
|
||||
{
|
||||
return type_convert<T>(std::floorf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float floor<float>(float x)
|
||||
{
|
||||
return std::floorf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double floor<double>(double x)
|
||||
{
|
||||
return std::floor(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T rcp(T x)
|
||||
{
|
||||
return type_convert<T>(1.f / type_convert<float>(x));
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T exp(T x)
|
||||
{
|
||||
return type_convert<T>(std::expf(type_convert<float>(x)));
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float exp<float>(float x)
|
||||
{
|
||||
return std::expf(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double exp<double>(double x)
|
||||
{
|
||||
return std::exp(x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T log(T x)
|
||||
{
|
||||
return type_convert<T>(std::logf(type_convert<float>(x)));
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float log<float>(float x)
|
||||
{
|
||||
return std::logf(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double log<double>(double x)
|
||||
{
|
||||
return std::log(x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T pow(T x, T gamma)
|
||||
{
|
||||
return type_convert<T>(std::powf(type_convert<float>(x), type_convert<float>(gamma)));
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float pow<float>(float x, float gamma)
|
||||
{
|
||||
return std::powf(x, gamma);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double pow<double>(double x, double gamma)
|
||||
{
|
||||
return std::pow(x, gamma);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST T expm1(T x)
|
||||
{
|
||||
return type_convert<T>(std::expm1f(type_convert<float>(x)));
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST float expm1<float>(float x)
|
||||
{
|
||||
return std::expm1f(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST double expm1<double>(double x)
|
||||
{
|
||||
return std::expm1(x);
|
||||
}
|
||||
|
||||
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
|
||||
|
||||
// Device absolute-value overloads.

// float: clear bit 31 of the 32-bit pattern. The bits are read and written with
// bit_cast (as the fp16_t overload does) rather than through a union, because
// reading the inactive member of a union is undefined behaviour in ISO C++.
CK_TILE_DEVICE float abs(float x)
{
    const uint32_t magnitude_bits = bit_cast<uint32_t>(x) & 0x7fffffff;
    return bit_cast<float>(magnitude_bits);
}

// double: forwards to the global-namespace abs.
CK_TILE_DEVICE double abs(double x)
{
    return ::abs(x);
}

// int8_t: branch-free form. `mask` is x shifted right by 7 bits (0 for
// non-negative x, all ones for negative x when the shift is arithmetic), so
// (x ^ mask) - mask negates x only when it is negative.
CK_TILE_DEVICE int8_t abs(int8_t x)
{
    const int8_t mask = x >> (8 - 1);
    return (x ^ mask) - mask;
}

// int32_t: same scheme with a 31-bit shift.
CK_TILE_DEVICE int32_t abs(int32_t x)
{
    const int32_t mask = x >> (32 - 1);
    return (x ^ mask) - mask;
}

#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// int4_t: same scheme with a 3-bit shift.
CK_TILE_DEVICE int4_t abs(int4_t x)
{
    const int4_t mask = x >> (4 - 1);
    return (x ^ mask) - mask;
}
#endif

// fp16_t: clear the top bit of the 16-bit pattern and reinterpret it as fp16_t.
CK_TILE_DEVICE fp16_t abs(fp16_t x)
{
    const uint16_t magnitude_bits = bit_cast<uint16_t>(x) & 0x7fff;
    return bit_cast<fp16_t>(magnitude_bits);
}
|
||||
|
||||
// Device NaN tests.

// float / double: forward to the global-namespace isnan.
CK_TILE_DEVICE bool isnan(float x)
{
    return ::isnan(x);
}

CK_TILE_DEVICE bool isnan(double x)
{
    return ::isnan(x);
}

// Integer overloads always report false; the argument is intentionally unused.
CK_TILE_DEVICE bool isnan(int8_t x)
{
    static_cast<void>(x);
    return false;
}

CK_TILE_DEVICE bool isnan(int32_t x)
{
    static_cast<void>(x);
    return false;
}

#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE bool isnan(int4_t x)
{
    static_cast<void>(x);
    return false;
}
#endif

// fp16_t: true when the 16-bit pattern with the top bit masked off is greater
// than 0x7C00.
CK_TILE_DEVICE bool isnan(fp16_t x)
{
    const uint16_t magnitude_bits = bit_cast<uint16_t>(x) & 0x7FFF;
    return magnitude_bits > 0x7C00;
}
|
||||
|
||||
// Device square roots, implemented with the AMDGCN sqrt builtins.

// fp16_t: widen to float, apply __builtin_amdgcn_sqrtf, narrow back to fp16_t.
CK_TILE_DEVICE fp16_t sqrt(fp16_t x)
{
    const float widened = static_cast<float>(x);
    return static_cast<fp16_t>(__builtin_amdgcn_sqrtf(widened));
}

CK_TILE_DEVICE float sqrt(float x)
{
    return __builtin_amdgcn_sqrtf(x);
}

CK_TILE_DEVICE double sqrt(double x)
{
    return __builtin_amdgcn_sqrt(x);
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T tanh(T x)
|
||||
{
|
||||
return type_convert<T>(::tanhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float tanh<float>(float x)
|
||||
{
|
||||
return ::tanhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double tanh<double>(double x)
|
||||
{
|
||||
return ::tanh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T acos(T x)
|
||||
{
|
||||
return type_convert<T>(::acosf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float acos<float>(float x)
|
||||
{
|
||||
return ::acosf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double acos<double>(double x)
|
||||
{
|
||||
return ::acos(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T neg(T x)
|
||||
{
|
||||
return type_convert<T>(-(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float neg<float>(float x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double neg<double>(double x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE int32_t neg<int32_t>(int32_t x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE int8_t neg<int8_t>(int8_t x)
|
||||
{
|
||||
return -x;
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE fp16_t neg<fp16_t>(fp16_t x)
|
||||
{
|
||||
return __hneg(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T atan(T x)
|
||||
{
|
||||
return type_convert<T>(::atanf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float atan<float>(float x)
|
||||
{
|
||||
return ::atanf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double atan<double>(double x)
|
||||
{
|
||||
return ::atan(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T sin(T x)
|
||||
{
|
||||
return type_convert<T>(::sinf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float sin<float>(float x)
|
||||
{
|
||||
return ::sinf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double sin<double>(double x)
|
||||
{
|
||||
return ::sin(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE fp16_t sin<fp16_t>(fp16_t x)
|
||||
{
|
||||
return ::hsin(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T asin(T x)
|
||||
{
|
||||
return type_convert<T>(::asinf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float asin<float>(float x)
|
||||
{
|
||||
return ::asinf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double asin<double>(double x)
|
||||
{
|
||||
return ::asin(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T asinh(T x)
|
||||
{
|
||||
return type_convert<T>(::asinhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float asinh<float>(float x)
|
||||
{
|
||||
return ::asinhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double asinh<double>(double x)
|
||||
{
|
||||
return ::asinh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T acosh(T x)
|
||||
{
|
||||
return type_convert<T>(::acoshf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float acosh<float>(float x)
|
||||
{
|
||||
return ::acoshf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double acosh<double>(double x)
|
||||
{
|
||||
return ::acosh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T tan(T x)
|
||||
{
|
||||
return type_convert<T>(::tanf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float tan<float>(float x)
|
||||
{
|
||||
return ::tanf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double tan<double>(double x)
|
||||
{
|
||||
return ::tan(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T atanh(T x)
|
||||
{
|
||||
return type_convert<T>(::atanhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float atanh<float>(float x)
|
||||
{
|
||||
return ::atanhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double atanh<double>(double x)
|
||||
{
|
||||
return ::atanh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T sinh(T x)
|
||||
{
|
||||
return type_convert<T>(::sinhf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float sinh<float>(float x)
|
||||
{
|
||||
return ::sinhf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double sinh<double>(double x)
|
||||
{
|
||||
return ::sinh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T ceil(T x)
|
||||
{
|
||||
return type_convert<T>(::ceilf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float ceil<float>(float x)
|
||||
{
|
||||
return ::ceilf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double ceil<double>(double x)
|
||||
{
|
||||
return ::ceil(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE fp16_t ceil<fp16_t>(fp16_t x)
|
||||
{
|
||||
return ::hceil(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T cosh(T x)
|
||||
{
|
||||
return type_convert<T>(::coshf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float cosh<float>(float x)
|
||||
{
|
||||
return ::coshf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double cosh<double>(double x)
|
||||
{
|
||||
return ::cosh(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T floor(T x)
|
||||
{
|
||||
return type_convert<T>(::floorf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float floor<float>(float x)
|
||||
{
|
||||
return ::floorf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double floor<double>(double x)
|
||||
{
|
||||
return ::floor(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE fp16_t floor<fp16_t>(fp16_t x)
|
||||
{
|
||||
return ::hfloor(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T rcp(T x)
|
||||
{
|
||||
#if !CK_TILE_WORKAROUND_SWDEV_383542
|
||||
return __frcp_rn(x);
|
||||
#else
|
||||
// return __ocml_native_recip_f32(x);
|
||||
return __builtin_amdgcn_rcpf(x);
|
||||
#endif
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T exp(T x)
|
||||
{
|
||||
return type_convert<T>(__ocml_exp_f32(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE fp16_t exp<fp16_t>(fp16_t x)
|
||||
{
|
||||
return hexp(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float exp<float>(float x)
|
||||
{
|
||||
return __ocml_exp_f32(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double exp<double>(double x)
|
||||
{
|
||||
return exp(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T log(T x)
|
||||
{
|
||||
return type_convert<T>(__logf(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE fp16_t log<fp16_t>(fp16_t x)
|
||||
{
|
||||
return hlog(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float log<float>(float x)
|
||||
{
|
||||
return __logf(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double log<double>(double x)
|
||||
{
|
||||
return log(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T pow(T x, T gamma)
|
||||
{
|
||||
return type_convert<T>(powf(type_convert<float>(x), type_convert<float>(gamma)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float pow<float>(float x, float gamma)
|
||||
{
|
||||
return powf(x, gamma);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double pow<double>(double x, double gamma)
|
||||
{
|
||||
return pow(x, gamma);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T expm1(T x)
|
||||
{
|
||||
return type_convert<T>(expm1f(type_convert<float>(x)));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float expm1<float>(float x)
|
||||
{
|
||||
return expm1f(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE double expm1<double>(double x)
|
||||
{
|
||||
return expm1(x);
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -91,8 +91,10 @@ struct buffer_view<address_space_enum::generic,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
CK_TILE_DEVICE constexpr auto get(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -107,11 +109,11 @@ struct buffer_view<address_space_enum::generic,
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp;
|
||||
|
||||
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
|
||||
__builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
|
||||
|
||||
return tmp;
|
||||
#else
|
||||
return *c_style_pointer_cast<const X*>(&p_data_[i]);
|
||||
return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
@@ -134,17 +136,17 @@ struct buffer_view<address_space_enum::generic,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set<X>(i, is_valid_element, x);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
// FIXME: remove memory_operation_enum::add
|
||||
else if constexpr(Op == memory_operation_enum::add)
|
||||
{
|
||||
auto tmp = this->template get<X>(i, is_valid_element);
|
||||
this->template set<X>(i, is_valid_element, x + tmp);
|
||||
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,7 +156,7 @@ struct buffer_view<address_space_enum::generic,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -169,9 +171,9 @@ struct buffer_view<address_space_enum::generic,
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp = x;
|
||||
|
||||
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
|
||||
__builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
|
||||
#else
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
*c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -276,8 +278,10 @@ struct buffer_view<address_space_enum::global,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
CK_TILE_DEVICE constexpr auto get(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -303,7 +307,7 @@ struct buffer_view<address_space_enum::global,
|
||||
t_per_x,
|
||||
Coherence,
|
||||
oob_conditional_check>(
|
||||
p_data_, i, is_valid_element, buffer_size_);
|
||||
p_data_, i + linear_offset, is_valid_element, buffer_size_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -311,8 +315,11 @@ struct buffer_view<address_space_enum::global,
|
||||
remove_cvref_t<T>,
|
||||
t_per_x,
|
||||
Coherence,
|
||||
oob_conditional_check>(
|
||||
p_data_, i, is_valid_element, buffer_size_, invalid_element_value_);
|
||||
oob_conditional_check>(p_data_,
|
||||
i + linear_offset,
|
||||
is_valid_element,
|
||||
buffer_size_,
|
||||
invalid_element_value_);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -322,11 +329,11 @@ struct buffer_view<address_space_enum::global,
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp;
|
||||
|
||||
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
|
||||
__builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
|
||||
|
||||
return tmp;
|
||||
#else
|
||||
return *c_style_pointer_cast<const X*>(&p_data_[i]);
|
||||
return *c_style_pointer_cast<const X*>(&p_data_[i + linear_offset]);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
@@ -352,7 +359,8 @@ struct buffer_view<address_space_enum::global,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
|
||||
index_t i,
|
||||
index_t v_offset,
|
||||
index_t i_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
@@ -366,7 +374,38 @@ struct buffer_view<address_space_enum::global,
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
|
||||
dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
|
||||
dst, cached_buf_res_, v_offset, i_offset, is_valid_element, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto async_get(CK_TILE_LDS_ADDR remove_cvref_t<T>* smem,
|
||||
index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
// X is vector of T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
|
||||
smem,
|
||||
cached_buf_res_,
|
||||
i,
|
||||
linear_offset,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
@@ -378,6 +417,7 @@ struct buffer_view<address_space_enum::global,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t<T>* smem,
|
||||
index_t i,
|
||||
index_t linear_offset,
|
||||
bool /*is_valid_element*/,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
@@ -391,7 +431,7 @@ struct buffer_view<address_space_enum::global,
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
|
||||
smem, cached_buf_res_, i, bool_constant<pre_nop>{});
|
||||
smem, cached_buf_res_, i, linear_offset, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
@@ -401,25 +441,25 @@ struct buffer_view<address_space_enum::global,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set<X>(i, is_valid_element, x);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == memory_operation_enum::atomic_add)
|
||||
{
|
||||
this->template atomic_add<X>(i, is_valid_element, x);
|
||||
this->template atomic_add<X>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == memory_operation_enum::atomic_max)
|
||||
{
|
||||
this->template atomic_max<X>(i, is_valid_element, x);
|
||||
this->template atomic_max<X>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
// FIXME: remove memory_operation_enum::add
|
||||
else if constexpr(Op == memory_operation_enum::add)
|
||||
{
|
||||
auto tmp = this->template get<X>(i, is_valid_element);
|
||||
this->template set<X>(i, is_valid_element, x + tmp);
|
||||
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
|
||||
// tmp += x;
|
||||
// this->template set<X>(i, is_valid_element, tmp);
|
||||
}
|
||||
@@ -432,7 +472,7 @@ struct buffer_view<address_space_enum::global,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -453,7 +493,7 @@ struct buffer_view<address_space_enum::global,
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_store<remove_cvref_t<T>, t_per_x, Coherence>(
|
||||
x, p_data_, i, is_valid_element, buffer_size_);
|
||||
x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -462,9 +502,9 @@ struct buffer_view<address_space_enum::global,
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp = x;
|
||||
|
||||
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
|
||||
__builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
|
||||
#else
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
*c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -477,7 +517,7 @@ struct buffer_view<address_space_enum::global,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void set_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -489,7 +529,7 @@ struct buffer_view<address_space_enum::global,
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
amd_buffer_store_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
|
||||
x, p_data_, i, is_valid_element, buffer_size_);
|
||||
x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
@@ -497,7 +537,8 @@ struct buffer_view<address_space_enum::global,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void
|
||||
atomic_add(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
|
||||
|
||||
@@ -532,13 +573,13 @@ struct buffer_view<address_space_enum::global,
|
||||
if constexpr(use_amd_buffer_addressing)
|
||||
{
|
||||
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, buffer_size_);
|
||||
x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
|
||||
atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -548,7 +589,8 @@ struct buffer_view<address_space_enum::global,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void
|
||||
atomic_max(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -572,11 +614,11 @@ struct buffer_view<address_space_enum::global,
|
||||
if constexpr(use_amd_buffer_addressing)
|
||||
{
|
||||
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, buffer_size_);
|
||||
x, p_data_, i + linear_offset, is_valid_element, buffer_size_);
|
||||
}
|
||||
else if(is_valid_element)
|
||||
{
|
||||
atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
|
||||
atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i + linear_offset], x);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -668,8 +710,10 @@ struct buffer_view<address_space_enum::lds,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
CK_TILE_DEVICE constexpr auto get(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -684,14 +728,14 @@ struct buffer_view<address_space_enum::lds,
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp;
|
||||
|
||||
__builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
|
||||
__builtin_memcpy(&tmp, &(p_data_[i + linear_offset]), sizeof(X));
|
||||
|
||||
return tmp;
|
||||
#else
|
||||
using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
|
||||
scalar_per_t_vector * scalar_per_x_vector>;
|
||||
// using buf_t = ushort __attribute__((ext_vector_type(8)));
|
||||
auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i]);
|
||||
auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i + linear_offset]);
|
||||
return bit_cast<X>(rtn);
|
||||
#endif
|
||||
}
|
||||
@@ -708,6 +752,23 @@ struct buffer_view<address_space_enum::lds,
|
||||
}
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
|
||||
index_t v_offset,
|
||||
index_t i_offset,
|
||||
bool /*is_valid_element*/,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
smem_load<sizeof(X)>{}(dst, v_offset * sizeof(T), i_offset * sizeof(T));
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
@@ -715,17 +776,17 @@ struct buffer_view<address_space_enum::lds,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set<X>(i, is_valid_element, x);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
// FIXME: remove memory_operation_enum::add
|
||||
else if constexpr(Op == memory_operation_enum::add)
|
||||
{
|
||||
auto tmp = this->template get<X>(i, is_valid_element);
|
||||
this->template set<X>(i, is_valid_element, x + tmp);
|
||||
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -735,7 +796,7 @@ struct buffer_view<address_space_enum::lds,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -751,6 +812,7 @@ struct buffer_view<address_space_enum::lds,
|
||||
bool constexpr workaround_int8_ds_write_issue = false;
|
||||
#endif
|
||||
|
||||
i += linear_offset; // simplicity
|
||||
if constexpr(std::is_same<typename vector_traits<remove_cvref_t<T>>::scalar_type,
|
||||
int8_t>::value &&
|
||||
workaround_int8_ds_write_issue)
|
||||
@@ -952,8 +1014,10 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
|
||||
CK_TILE_DEVICE constexpr auto get(index_t i,
|
||||
index_t /*linear_offset*/,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -995,17 +1059,17 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set<X>(i, is_valid_element, x);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
// FIXME: remove memory_operation_enum::add
|
||||
else if constexpr(Op == memory_operation_enum::add)
|
||||
{
|
||||
auto tmp = this->template get<X>(i, is_valid_element);
|
||||
this->template set<X>(i, is_valid_element, x + tmp);
|
||||
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1015,7 +1079,7 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void set(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
@@ -1030,9 +1094,9 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
|
||||
X tmp = x;
|
||||
|
||||
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
|
||||
__builtin_memcpy(&(p_data_[i + linear_offset]), &tmp, sizeof(X));
|
||||
#else
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
*c_style_pointer_cast<X*>(&p_data_[i + linear_offset]) = x;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/tensor/null_tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/null_tensor.hpp"
|
||||
|
||||
@@ -28,7 +29,21 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT
|
||||
NumCoord>& tile_window,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(bool_constant<oob_conditional_check>{});
|
||||
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
@@ -46,7 +61,27 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
tile_window.load_raw(
|
||||
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.load_raw(
|
||||
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
@@ -66,7 +101,26 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
return tile_window.async_load_raw(
|
||||
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
return tile_window.async_load_raw(
|
||||
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
|
||||
|
||||
@@ -109,7 +109,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
|
||||
|
||||
// get input vectors
|
||||
static_for<0, num_vec_in, 1>{}([&](auto i) {
|
||||
constexpr auto idx_y_in = generate_array(
|
||||
constexpr auto idx_y_in = generate_tuple(
|
||||
[&](auto ii) {
|
||||
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
|
||||
},
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -72,7 +73,7 @@ store_tile(tile_window_with_static_distribution<BottomTensorView_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store(dstr_tensor);
|
||||
tile_window.store(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
@@ -87,7 +88,33 @@ store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store_raw(dstr_tensor);
|
||||
tile_window.store_raw(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void store_tile(
|
||||
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
|
||||
tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
typename DataType_>
|
||||
CK_TILE_DEVICE void store_tile_raw(
|
||||
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
|
||||
tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
{
|
||||
tile_window.store_raw(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -16,6 +16,24 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
* tensor_view
|
||||
* abstracts the underlying memory buffer (global, LDS, etc...)
|
||||
* and provide a unified get/set function for access
|
||||
*
|
||||
* For addressing into the buffer we use 2 variables:
|
||||
* coord : ND tensor coordinate, will calculate the actual offset inside
|
||||
* linear_offset : 1D offset, will be used in the immediate field of
|
||||
* the buffer instruction to help reduce register usage
|
||||
*
|
||||
* The user can use either of the fields, or both, to index into the tensor
|
||||
*
|
||||
* We usually provide 2 sets of APIs for buffer get/set, e.g.
|
||||
* get_vectorized_elements()/get_vectorized_elements_raw()
|
||||
* the former usually calls an intrinsic or a normal C function, the latter
|
||||
* usually calls an inline-asm function
|
||||
*
|
||||
*/
|
||||
template <typename BufferView_,
|
||||
typename TensorDesc_,
|
||||
memory_operation_enum DstInMemOp_ = memory_operation_enum::set>
|
||||
@@ -49,22 +67,6 @@ struct tensor_view
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto& get_buffer_view() { return buf_; }
|
||||
|
||||
#if 0
|
||||
CK_TILE_HOST_DEVICE constexpr DataType get_element(const TensorCoord& coord) const
|
||||
{
|
||||
return buf_.template get<DataType>(
|
||||
coord.get_offset(),
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void set_element(const TensorCoord& coord, const DataType& x)
|
||||
{
|
||||
buf_.template set<DataType>(
|
||||
coord.get_offset(),
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
#endif
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
@@ -75,14 +77,34 @@ struct tensor_view
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template get<X>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
|
||||
get_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element, // flag
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
return buf_.template get<X>(coord.get_offset(),
|
||||
linear_offset,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
@@ -94,12 +116,90 @@ struct tensor_view
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
|
||||
dst,
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
|
||||
dst, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset) const
|
||||
{
|
||||
return buf_.template async_get<X>(
|
||||
smem,
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element) const
|
||||
{
|
||||
return buf_.template async_get<X>(smem,
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
is_valid_element,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template async_get_raw<X>(
|
||||
smem,
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
@@ -110,11 +210,15 @@ struct tensor_view
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
|
||||
remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
|
||||
const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
return buf_.template async_get_raw<X>(
|
||||
smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
|
||||
smem, coord.get_offset(), linear_offset, is_valid_element, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
@@ -125,11 +229,15 @@ struct tensor_view
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements(
|
||||
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set<X, oob_conditional_check>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
@@ -140,15 +248,53 @@ struct tensor_view
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw(
|
||||
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set<X, oob_conditional_check>(
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set_raw<X, oob_conditional_check>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
set_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template set_raw<X, oob_conditional_check>(
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
@@ -157,15 +303,36 @@ struct tensor_view
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(
|
||||
const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template update<DstInMemOp, X, oob_conditional_check>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
buf_.template update<DstInMemOp, X, oob_conditional_check>(
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_view{");
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Note: this tile window does not support single issue
|
||||
// you need to use tile_window_linear structure for this purpose
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
@@ -41,6 +43,7 @@ struct tile_window_with_static_distribution
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static_assert(NumCoord == 1);
|
||||
|
||||
// TODO: check WindowLengths and StaticTileDistribution are consistent
|
||||
|
||||
@@ -189,7 +192,8 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys =
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
@@ -222,10 +226,11 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// move thread's window adaptor coordinate and bottom tensor coordinate
|
||||
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
|
||||
template <typename ATopIndex>
|
||||
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
WindowAdaptorCoord& window_adaptor_thread_coord,
|
||||
BottomTensorCoord& bottom_tensor_thread_coord,
|
||||
const AdaptorTopIndex& idx_diff_adaptor_top) const
|
||||
const ATopIndex& idx_diff_adaptor_top) const
|
||||
{
|
||||
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
|
||||
|
||||
@@ -279,10 +284,11 @@ struct tile_window_with_static_distribution
|
||||
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; }
|
||||
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; }
|
||||
|
||||
template <bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
@@ -308,11 +314,11 @@ struct tile_window_with_static_distribution
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
|
||||
bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
#if 1
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
@@ -338,8 +344,9 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys =
|
||||
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
@@ -350,8 +357,12 @@ struct tile_window_with_static_distribution
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
|
||||
template <typename DstTile,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
@@ -397,6 +408,7 @@ struct tile_window_with_static_distribution
|
||||
get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
|
||||
dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
|
||||
bottom_tensor_thread_coord,
|
||||
0 /**/,
|
||||
bool_constant<oob_conditional_check>{},
|
||||
pre_nop_);
|
||||
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
|
||||
@@ -409,23 +421,24 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys =
|
||||
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
|
||||
asm volatile("; this inline asm is workaround to prevent compiler from using too much "
|
||||
"scratch memory" ::);
|
||||
#endif
|
||||
}
|
||||
|
||||
// TODO: currently async load only implemented in inline asm
|
||||
template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
|
||||
template <typename LdsTileWindow_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
@@ -467,7 +480,7 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
@@ -482,15 +495,16 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, pre_nop_);
|
||||
smem, bottom_tensor_thread_coord, 0, pre_nop_);
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys =
|
||||
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
@@ -501,8 +515,81 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
template <bool oob_conditional_check = true>
|
||||
template <typename LdsTileWindow_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
|
||||
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
|
||||
// check?)
|
||||
constexpr index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{}));
|
||||
|
||||
constexpr index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) -
|
||||
size_per_buf;
|
||||
|
||||
constexpr index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
|
||||
using Traits = load_store_traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
// TODO: we force CK_TILE_LDS_ADDR
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem =
|
||||
lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// read from bottom tensor
|
||||
get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
smem += size_per_issue; // Note we manually increase the per-issue offset
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
@@ -515,7 +602,6 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
@@ -530,7 +616,7 @@ struct tile_window_with_static_distribution
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
@@ -548,15 +634,19 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys =
|
||||
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
@@ -565,8 +655,9 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
|
||||
template <index_t i_access_unsupport_ = -1>
|
||||
CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
@@ -591,7 +682,7 @@ struct tile_window_with_static_distribution
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
@@ -606,15 +697,16 @@ struct tile_window_with_static_distribution
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view()
|
||||
.template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
|
||||
bottom_tensor_thread_coord, vec_value);
|
||||
bottom_tensor_thread_coord, 0, vec_value);
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys =
|
||||
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
@@ -623,8 +715,9 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
template <bool oob_conditional_check = true>
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
@@ -650,7 +743,7 @@ struct tile_window_with_static_distribution
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
@@ -666,15 +759,19 @@ struct tile_window_with_static_distribution
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys =
|
||||
container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
@@ -746,7 +843,8 @@ struct tile_window_with_static_distribution
|
||||
constexpr auto idx_diff_ys =
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
@@ -798,6 +896,27 @@ make_tile_window(const TensorView_& tensor_view,
|
||||
tensor_view, window_lengths, origin, tile_distribution};
|
||||
}
|
||||
|
||||
// this version can't be called in a constexpr context
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
index_t NumCoord = 1>
|
||||
CK_TILE_DEVICE auto
|
||||
make_tile_window_raw(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution_& tile_distribution,
|
||||
number<NumCoord> = {})
|
||||
{
|
||||
auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
remove_cvref_t<StaticTileDistribution_>,
|
||||
NumCoord>{
|
||||
tensor_view, window_lengths, origin, tile_distribution};
|
||||
w.init_raw();
|
||||
return w;
|
||||
}
|
||||
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
@@ -922,6 +1041,19 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
|
||||
tile_distribution);
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution)
|
||||
{
|
||||
auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
tile_window.get_window_origin(),
|
||||
tile_distribution);
|
||||
w.init_raw();
|
||||
return w;
|
||||
}
|
||||
|
||||
template <typename TensorView_, typename WindowLengths_>
|
||||
CK_TILE_DEVICE void move_tile_window(
|
||||
tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
|
||||
|
||||
1082
include/ck_tile/core/tensor/tile_window_linear.hpp
Normal file
1082
include/ck_tile/core/tensor/tile_window_linear.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -59,8 +59,16 @@ struct magic_division32_bit_range
|
||||
CK_TILE_DEVICE static constexpr uint32_t
|
||||
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t tmp = __umulhi(dividend, multiplier);
|
||||
return (tmp + dividend) >> shift;
|
||||
if(__builtin_is_constant_evaluated())
|
||||
{
|
||||
uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t tmp = __umulhi(dividend, multiplier);
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr uint32_t
|
||||
@@ -77,9 +85,18 @@ struct magic_division32_bit_range
|
||||
CK_TILE_DEVICE static constexpr int32_t
|
||||
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = __umulhi(dividend_u32, multiplier);
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
if(__builtin_is_constant_evaluated())
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
|
||||
uint32_t tmp = __umulhi(dividend_u32, multiplier);
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr int32_t
|
||||
|
||||
@@ -24,5 +24,6 @@
|
||||
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
|
||||
#include "ck_tile/host/reference/reference_reduce.hpp"
|
||||
#include "ck_tile/host/reference/reference_softmax.hpp"
|
||||
#include "ck_tile/host/reference/reference_topk.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/timer.hpp"
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <random>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
@@ -41,6 +42,73 @@ struct FillUniformDistribution
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
|
||||
// clang-format off
|
||||
template<index_t bytes> struct RawIntegerType_ {};
|
||||
template<> struct RawIntegerType_<1> { using type = uint8_t;};
|
||||
template<> struct RawIntegerType_<2> { using type = uint16_t;};
|
||||
template<> struct RawIntegerType_<4> { using type = uint32_t;};
|
||||
template<> struct RawIntegerType_<8> { using type = uint64_t;};
|
||||
// clang-format on
|
||||
|
||||
template <typename T>
|
||||
using RawIntegerType = typename RawIntegerType_<sizeof(T)>::type;
|
||||
} // namespace impl
|
||||
|
||||
// Note: this struct will have no const-ness will generate random
|
||||
template <typename T>
|
||||
struct FillUniformDistribution_Unique
|
||||
{
|
||||
float a_{-5.f};
|
||||
float b_{5.f};
|
||||
std::optional<uint32_t> seed_{11939};
|
||||
|
||||
std::mt19937 gen_{};
|
||||
std::unordered_set<impl::RawIntegerType<T>> set_{};
|
||||
|
||||
FillUniformDistribution_Unique(float a = -5.f,
|
||||
float b = 5.f,
|
||||
std::optional<uint32_t> seed = {11939})
|
||||
: a_(a),
|
||||
b_(b),
|
||||
seed_(seed),
|
||||
gen_{seed_.has_value() ? *seed_ : std::random_device{}()},
|
||||
set_{}
|
||||
{
|
||||
}
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last)
|
||||
{
|
||||
std::mt19937& gen = gen_;
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
auto& set = set_;
|
||||
std::generate(first, last, [&dis, &gen, &set]() {
|
||||
T v = static_cast<T>(0);
|
||||
do
|
||||
{
|
||||
v = ck_tile::type_convert<T>(dis(gen));
|
||||
} while(set.count(bit_cast<impl::RawIntegerType<T>>(v)) == 1);
|
||||
set.insert(bit_cast<impl::RawIntegerType<T>>(v));
|
||||
|
||||
return v;
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range)
|
||||
-> std::void_t<decltype(std::declval<FillUniformDistribution_Unique&>()(
|
||||
std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
}
|
||||
|
||||
void clear() { set_.clear(); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct FillNormalDistribution
|
||||
{
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/ranges.hpp"
|
||||
@@ -545,6 +546,28 @@ struct HostTensor
|
||||
|
||||
typename Data::size_type size() const { return mData.size(); }
|
||||
|
||||
// return a slice of this tensor
|
||||
// for simplicity we just copy the data and return a new tensor
|
||||
auto slice(std::vector<size_t> s_begin, std::vector<size_t> s_end) const
|
||||
{
|
||||
assert(s_begin.size() == s_end.size());
|
||||
assert(s_begin.size() == get_num_of_dimension());
|
||||
|
||||
std::vector<size_t> s_len(s_begin.size());
|
||||
std::transform(
|
||||
s_end.begin(), s_end.end(), s_begin.begin(), s_len.begin(), std::minus<size_t>{});
|
||||
HostTensor<T> sliced_tensor(s_len);
|
||||
|
||||
sliced_tensor.ForEach([&](auto& self, auto idx) {
|
||||
std::vector<size_t> src_idx(idx.size());
|
||||
std::transform(
|
||||
idx.begin(), idx.end(), s_begin.begin(), src_idx.begin(), std::plus<size_t>{});
|
||||
self(idx) = operator()(src_idx);
|
||||
});
|
||||
|
||||
return sliced_tensor;
|
||||
}
|
||||
|
||||
template <typename U = T>
|
||||
auto AsSpan() const
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -9,43 +9,81 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType, typename AccDataType, typename BDataType>
|
||||
CK_TILE_HOST void reference_softmax(const HostTensor<ADataType>& a_m_n,
|
||||
HostTensor<BDataType>& b_m_n)
|
||||
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
|
||||
CK_TILE_HOST void
|
||||
reference_softmax(const HostTensor<InputType>& x, HostTensor<OutputType>& y, index_t dim = -1)
|
||||
{
|
||||
auto f = [&](auto m) {
|
||||
const int N = a_m_n.mDesc.get_lengths()[1];
|
||||
index_t rank = x.get_num_of_dimension();
|
||||
assert(rank == y.get_num_of_dimension());
|
||||
assert(dim == -1 || dim < rank);
|
||||
|
||||
AccDataType v_max = ck_tile::numeric<ADataType>::Lowest();
|
||||
index_t target_dim = dim == -1 ? (rank - 1) : dim;
|
||||
index_t softmax_len = x.get_length(target_dim);
|
||||
index_t n_parallel = x.get_element_size() / softmax_len;
|
||||
auto x_len = x.get_lengths();
|
||||
|
||||
// max
|
||||
for(int n = 0; n < N; ++n)
|
||||
auto f = [&](auto i_element) {
|
||||
std::vector<size_t> coord = [&]() {
|
||||
std::vector<size_t> t_(rank, 0);
|
||||
size_t r = i_element;
|
||||
for(index_t i = rank - 1; i >= 0; i--)
|
||||
{
|
||||
if(i == target_dim)
|
||||
continue;
|
||||
t_[i] = r % x_len[i];
|
||||
r = r / x_len[i];
|
||||
}
|
||||
return t_;
|
||||
}();
|
||||
|
||||
ComputeType v_max = -ck_tile::numeric<ComputeType>::infinity();
|
||||
|
||||
// compute max
|
||||
for(auto idx = 0; idx < softmax_len; idx++)
|
||||
{
|
||||
const ADataType v_a = a_m_n(m, n);
|
||||
|
||||
v_max = v_max < v_a ? v_a : v_max;
|
||||
auto c_ = coord;
|
||||
c_[target_dim] = idx;
|
||||
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
|
||||
v_max = v_max < v_x ? v_x : v_max;
|
||||
}
|
||||
|
||||
AccDataType v_exp_sum = 0;
|
||||
ComputeType v_exp_sum = static_cast<ComputeType>(0);
|
||||
|
||||
// sum
|
||||
for(int n = 0; n < N; ++n)
|
||||
for(auto idx = 0; idx < softmax_len; idx++)
|
||||
{
|
||||
const ADataType v_a = a_m_n(m, n);
|
||||
auto c_ = coord;
|
||||
c_[target_dim] = idx;
|
||||
|
||||
v_exp_sum += ck_tile::exp(v_a - v_max);
|
||||
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
|
||||
|
||||
v_exp_sum += ck_tile::exp(v_x - v_max);
|
||||
}
|
||||
|
||||
// elementwise
|
||||
for(int n = 0; n < N; ++n)
|
||||
for(auto idx = 0; idx < softmax_len; idx++)
|
||||
{
|
||||
const ADataType v_a = a_m_n(m, n);
|
||||
auto c_ = coord;
|
||||
c_[target_dim] = idx;
|
||||
|
||||
b_m_n(m, n) = ck_tile::exp(v_a - v_max) / v_exp_sum;
|
||||
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
|
||||
|
||||
auto out = ck_tile::exp(v_x - v_max) / v_exp_sum;
|
||||
|
||||
y(c_) = ck_tile::type_convert<OutputType>(out);
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f,
|
||||
b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
|
||||
CK_TILE_HOST auto reference_softmax(const HostTensor<InputType>& x, index_t dim = -1)
|
||||
{
|
||||
HostTensor<OutputType> y(x.get_lengths(), x.get_strides());
|
||||
|
||||
reference_softmax<InputType, ComputeType, OutputType>(x, y, dim);
|
||||
|
||||
return y;
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
124
include/ck_tile/host/reference/reference_topk.hpp
Normal file
124
include/ck_tile/host/reference/reference_topk.hpp
Normal file
@@ -0,0 +1,124 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include <thread>
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
similar to torch.topk()
|
||||
x (Tensor) – the input tensor.
|
||||
k (int) – the k in “top-k”
|
||||
dim (int, optional) – the dimension to sort along
|
||||
largest (bool, optional) – largest or smallest elements
|
||||
sorted (bool, optional) – elements in sorted order or not
|
||||
|
||||
output:
|
||||
y_values
|
||||
y_indices
|
||||
|
||||
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TopKImpl.h
|
||||
*/
|
||||
template <typename DataType, typename IndexType = index_t>
|
||||
CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
|
||||
HostTensor<DataType>& y_values,
|
||||
HostTensor<IndexType>& y_indices,
|
||||
index_t k,
|
||||
index_t dim = -1,
|
||||
bool largest = true,
|
||||
bool sorted = true)
|
||||
{
|
||||
// rank must be the same
|
||||
index_t rank = x.get_num_of_dimension();
|
||||
assert(rank == y_values.get_num_of_dimension());
|
||||
assert(rank == y_indices.get_num_of_dimension());
|
||||
assert(dim == -1 || dim < rank);
|
||||
|
||||
index_t topk_dim = dim == -1 ? (rank - 1) : dim;
|
||||
index_t topk_src_len = x.get_length(topk_dim);
|
||||
auto x_len = x.get_lengths();
|
||||
|
||||
assert(k <= topk_src_len);
|
||||
assert(k == y_values.get_length(topk_dim) && k == y_indices.get_length(topk_dim));
|
||||
|
||||
index_t n_parallel = x.get_element_size() / topk_src_len;
|
||||
|
||||
// clang-format off
|
||||
auto f = [&](auto i_element) {
|
||||
std::vector<size_t> topk_coord = [&](){
|
||||
std::vector<size_t> t_(rank, 0);
|
||||
size_t r = i_element;
|
||||
for(index_t i = rank - 1; i >= 0; i--) {
|
||||
if(i == topk_dim) continue; // topk dim should be zero
|
||||
t_[i] = r % x_len[i]; r = r / x_len[i];
|
||||
}
|
||||
return t_;
|
||||
}();
|
||||
|
||||
using elem_t = std::pair<DataType, IndexType>;
|
||||
std::vector<elem_t> q = [&](){
|
||||
std::vector<elem_t> t_(topk_src_len);
|
||||
for(index_t i = 0; i < topk_src_len; i++) {
|
||||
auto c_ = topk_coord; c_[topk_dim] = i;
|
||||
t_[i].first = x(c_); t_[i].second = i;
|
||||
}
|
||||
return t_;
|
||||
}();
|
||||
|
||||
// run topk
|
||||
if(largest) {
|
||||
std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
|
||||
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
|
||||
if(sorted) {
|
||||
std::sort(q.begin(), q.begin() + k - 1,
|
||||
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
|
||||
}
|
||||
} else {
|
||||
std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
|
||||
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
|
||||
if(sorted) {
|
||||
std::sort(q.begin(), q.begin() + k - 1,
|
||||
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
|
||||
}
|
||||
}
|
||||
|
||||
// write out
|
||||
for(index_t i = 0; i < k; i++) {
|
||||
auto c_ = topk_coord; c_[topk_dim] = i;
|
||||
y_values(c_) = q[i].first; y_indices(c_) = q[i].second;
|
||||
}
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
// TODO: if using this method, the return tensor would be dense(no stride)
|
||||
template <typename DataType, typename IndexType = index_t>
|
||||
CK_TILE_HOST auto reference_topk(const HostTensor<DataType>& x,
|
||||
index_t k,
|
||||
index_t dim = -1,
|
||||
bool largest = true,
|
||||
bool sorted = true)
|
||||
{
|
||||
auto lens = x.get_lengths();
|
||||
index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim;
|
||||
assert(target_dim < lens.size());
|
||||
assert(k <= lens[target_dim]);
|
||||
lens[target_dim] = k;
|
||||
HostTensor<DataType> y_values(lens);
|
||||
HostTensor<IndexType> y_indices(lens);
|
||||
|
||||
reference_topk<DataType, IndexType>(x, y_values, y_indices, k, dim, largest, sorted);
|
||||
|
||||
return ck_tile::make_tuple(y_values, y_indices);
|
||||
}
|
||||
} // namespace ck_tile
|
||||
7
include/ck_tile/ops/elementwise.hpp
Normal file
7
include/ck_tile/ops/elementwise.hpp
Normal file
@@ -0,0 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
1163
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
Normal file
1163
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -334,7 +334,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
|
||||
buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
|
||||
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
|
||||
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
|
||||
|
||||
@@ -359,7 +359,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
if constexpr(i_k0 < k0_loops - 1)
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
async_load_fence(k_dram_window.get_num_access());
|
||||
async_load_fence(k_dram_window.get_num_of_access());
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
gemm_0(s_acc,
|
||||
|
||||
@@ -4,9 +4,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <tuple>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
* TODO: block_tile_reduce_sync() currently has a limitation
|
||||
* Y dim must have at least one dim that is not reduced
|
||||
*/
|
||||
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
|
||||
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
|
||||
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
@@ -104,6 +109,65 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
});
|
||||
}
|
||||
|
||||
/*
|
||||
* this version is faster, using xor to do reduce, no need broadcast anymore
|
||||
* TODO: the limitation is to-be-reduced P dim can only mapping to one R dim?
|
||||
*/
|
||||
template <typename AccDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
|
||||
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
|
||||
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
|
||||
constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
// loop over thread data
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
auto v_local = acc_tensor.get_thread_buffer()[i];
|
||||
|
||||
// cross-lane reduce for replication
|
||||
// only reduce on R dimension correspond to lane
|
||||
// (lane id maps to this R dimension)
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
// FIXME: nasty to use does_p_own_r_
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
|
||||
{
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
|
||||
constexpr index_t lid_over_rid_derivative =
|
||||
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
|
||||
|
||||
static_assert(is_power_of_two_integer(r_length),
|
||||
"wrong! only support power of 2 reduction");
|
||||
|
||||
constexpr index_t nstage = integer_log2_floor(r_length);
|
||||
|
||||
// reduction sweep forward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
// xor
|
||||
index_t src_lane =
|
||||
__lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote = warp_shuffle(v_local, src_lane);
|
||||
|
||||
// reduce
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
acc_tensor.get_thread_buffer()(i) = v_local;
|
||||
});
|
||||
}
|
||||
|
||||
// FIXME: this is for 2D to 1D reduce only, need to support n-D
|
||||
template <typename AccDistributedTensor_,
|
||||
typename InDistributedTensor_,
|
||||
@@ -175,6 +239,10 @@ CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* TODO: block_tile_reduce() currently has a limitation
|
||||
* Y dim must have at least one dim that is not reduced
|
||||
*/
|
||||
template <typename AccDataType_,
|
||||
typename InDistributedTensor_,
|
||||
index_t... InReduceDims,
|
||||
@@ -208,4 +276,106 @@ CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
|
||||
return acc_tensor;
|
||||
}
|
||||
|
||||
// this version only support 2D->1D reduce (reduce-dim=seq<0, 1>)
|
||||
// this version only support in/acc/out datatypes are the same
|
||||
// this version will call thread/warp+sync in one function call
|
||||
//
|
||||
template <typename InDistributedTensor_>
|
||||
struct BlockReduce2D
|
||||
{
|
||||
using InDistributedTensor = remove_cvref_t<InDistributedTensor_>;
|
||||
using InDataType = typename InDistributedTensor::DataType;
|
||||
|
||||
CK_TILE_HOST_DEVICE BlockReduce2D(const InDistributedTensor& t_, const InDataType& reduce_init_)
|
||||
: t(t_), reduce_init(reduce_init_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto MakeDstBlockTile() const
|
||||
{
|
||||
using ReduceDim = sequence<1>; // hard coded
|
||||
constexpr auto acc_dstr =
|
||||
make_static_tile_distribution(ck_tile::detail::make_reduce_tile_distribution_encoding(
|
||||
InDistributedTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding(),
|
||||
ReduceDim{}));
|
||||
|
||||
return make_static_distributed_tensor<InDataType>(acc_dstr);
|
||||
}
|
||||
|
||||
// return number of pixels each lane need to reduce
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_reduce_length_y() const
|
||||
{
|
||||
constexpr auto spans = InDistributedTensor::get_distributed_spans();
|
||||
}
|
||||
|
||||
// Here ReducePacksPerXDim is not the same meaning as that in static_uford/sweep_tile_uspan
|
||||
// this is number of packs along the X-dim. We need to compute the Unpacks along the Y dim
|
||||
// internally
|
||||
// For simplicity, we just support along the row dimension, ReducePacksPerXDim is always 2
|
||||
// element , and the first element is always ignored For simplicity, will always try from
|
||||
// right-to-left to find alone which Y dim to split
|
||||
template <typename ReduceFunc,
|
||||
typename ReduceSyncFunc,
|
||||
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func,
|
||||
const ReduceSyncFunc& reduce_sync_func,
|
||||
ReducePacksPerXDim = {}) const
|
||||
{
|
||||
constexpr auto spans = InDistributedTensor::get_distributed_spans();
|
||||
|
||||
constexpr auto row_y_unpacks = [&]() {
|
||||
constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
|
||||
constexpr auto row_y_size =
|
||||
reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
|
||||
constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});
|
||||
|
||||
static_assert(row_y_size % row_y_packs == 0);
|
||||
|
||||
constexpr auto row_y_slice_size = row_y_size / row_y_packs;
|
||||
|
||||
constexpr auto slice_info = slice_sequence(row_y_lengths, number<row_y_slice_size>{});
|
||||
constexpr auto unpacks = slice_info[number<1>{}];
|
||||
return unpacks;
|
||||
}();
|
||||
|
||||
auto acc_tensor = MakeDstBlockTile();
|
||||
|
||||
// in-thread reduction
|
||||
// FIXME: hard coded to be 2D to 1D reduction
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto dstr_idx_i0) {
|
||||
constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);
|
||||
|
||||
auto acc = acc_tensor[acc_dstr_idx];
|
||||
|
||||
sweep_tile_uspan(
|
||||
spans[number<1>{}],
|
||||
[&](auto... dstr_idx_i1) {
|
||||
acc = reduce_func(acc, t[make_tuple(dstr_idx_i0, dstr_idx_i1)]...);
|
||||
},
|
||||
row_y_unpacks);
|
||||
|
||||
acc_tensor(acc_dstr_idx) = acc;
|
||||
});
|
||||
|
||||
// TODO: always use xor to do cross-lane reduce
|
||||
block_tile_reduce_xor_sync(acc_tensor, reduce_sync_func);
|
||||
|
||||
return acc_tensor;
|
||||
}
|
||||
|
||||
template <typename ReduceFunc>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func) const
|
||||
{
|
||||
return operator()(reduce_func, reduce_func);
|
||||
}
|
||||
|
||||
InDistributedTensor t;
|
||||
InDataType reduce_init;
|
||||
};
|
||||
|
||||
// deduction guide
|
||||
// Lets the tensor type be deduced from the constructor arguments, so that
// `BlockReduce2D{tensor, init}` can be written without spelling the template argument.
template <typename T>
|
||||
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&)->BlockReduce2D<T>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
8
include/ck_tile/ops/softmax.hpp
Normal file
8
include/ck_tile/ops/softmax.hpp
Normal file
@@ -0,0 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
|
||||
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
81
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp
Normal file
81
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp
Normal file
@@ -0,0 +1,81 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
|
||||
#define _BLOCK_SOFTMAX_USE_UNPACK2 0
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
simple 2d softmax implementation, along row (dim=1)
|
||||
requirement:
|
||||
1). each row is within a warp
|
||||
2). data type must be a dword
|
||||
*/
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockSoftmax2D
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using DataType = typename Problem::DataType;
|
||||
|
||||
template <typename DistributedTensor, index_t dim = 1>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(const DistributedTensor& x, DistributedTensor& y, number<dim> = {})
|
||||
{
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
#if _BLOCK_SOFTMAX_USE_UNPACK2
|
||||
const auto f_max3 = [](auto e0, auto e1, auto e2) {
|
||||
float rtn;
|
||||
asm volatile("v_max3_f32 %0, %1, %2, %3" : "=v"(rtn) : "v"(e0), "v"(e1), "v"(e2));
|
||||
return rtn;
|
||||
};
|
||||
const auto f_sum3 = [](auto e0, auto e1, auto e2) { return e0 + e1 + e2; };
|
||||
#endif
|
||||
|
||||
// compute row max
|
||||
auto reduce_row_max = BlockReduce2D{x, -numeric<DataType>::infinity()};
|
||||
#if _BLOCK_SOFTMAX_USE_UNPACK2
|
||||
auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
|
||||
#else
|
||||
auto row_max = reduce_row_max(f_max);
|
||||
#endif
|
||||
sweep_tile<DistributedTensor>([&](auto idx) {
|
||||
constexpr auto row_id = make_tuple(idx[number<0>{}]);
|
||||
y(idx) = exp(x[idx] - row_max[row_id]);
|
||||
});
|
||||
|
||||
// compute row sum
|
||||
auto reduce_row_sum = BlockReduce2D<decltype(y)>{y, DataType{0}};
|
||||
#if _BLOCK_SOFTMAX_USE_UNPACK2
|
||||
auto row_sum = reduce_row_sum(f_sum3, f_sum, sequence<1, 2>{});
|
||||
#else
|
||||
auto row_sum = reduce_row_sum(f_sum);
|
||||
#endif
|
||||
// reciprocal
|
||||
auto r = make_static_distributed_tensor<DataType>(row_sum.get_tile_distribution());
|
||||
sweep_tile(row_sum, [&](auto idx) { r(idx) = DataType{1} / row_sum(idx); });
|
||||
|
||||
// scale
|
||||
sweep_tile<DistributedTensor>([&](auto idx) {
|
||||
constexpr auto row_id = make_tuple(idx[number<0>{}]);
|
||||
y(idx) = y(idx) * r(row_id);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor, index_t dim = 1>
|
||||
CK_TILE_DEVICE decltype(auto) operator()(const DistributedTensor& x, number<dim> = {})
|
||||
{
|
||||
auto y = DistributedTensor{}; // distributed tensor
|
||||
operator()(x, y, number<dim>{});
|
||||
return y;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,16 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Compile-time problem description for the 2D block softmax; it carries only the element type.
template <typename DataType_>
|
||||
struct BlockSoftmax2DProblem
|
||||
{
|
||||
// element type with const/volatile/reference qualifiers removed
using DataType = remove_cvref_t<DataType_>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
8
include/ck_tile/ops/topk.hpp
Normal file
8
include/ck_tile/ops/topk.hpp
Normal file
@@ -0,0 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
|
||||
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
113
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp
Normal file
113
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp
Normal file
@@ -0,0 +1,113 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
simple 2d topk implementation, along row (dim=1)
|
||||
requirement:
|
||||
1). each row is within a warp
|
||||
*/
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockTopkStream2D
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using DataType = typename Problem::DataType;
|
||||
using IndexType = typename Problem::IndexType;
|
||||
|
||||
// TODO: if DataType is subdword, need pack into single dword to use argmax
|
||||
struct ArgmaxPacket
|
||||
{
|
||||
DataType arg;
|
||||
index_t value;
|
||||
};
|
||||
|
||||
template <typename DistributedTensor, typename OutWindow, typename IdxWindow, index_t dim = 1>
|
||||
CK_TILE_DEVICE void operator()(const DistributedTensor& x,
|
||||
const OutWindow& out_window,
|
||||
const IdxWindow& idx_window,
|
||||
index_t k,
|
||||
number<dim> = {})
|
||||
{
|
||||
OutWindow out_window_tmp = out_window;
|
||||
IdxWindow idx_window_tmp = idx_window;
|
||||
static_assert(
|
||||
std::is_same_v<typename DistributedTensor::DataType, typename OutWindow::DataType> &&
|
||||
std::is_same_v<typename DistributedTensor::DataType, DataType>);
|
||||
static_assert(std::is_same_v<typename IdxWindow::DataType, IndexType>);
|
||||
|
||||
DistributedTensor x_tmp = x;
|
||||
constexpr auto dst_dist = typename IdxWindow::TileDstr{};
|
||||
|
||||
// argmax for topk
|
||||
const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
|
||||
return e0.arg > e1.arg ? e0 : e1;
|
||||
};
|
||||
|
||||
for(index_t i_k = 0; i_k < k; i_k++)
|
||||
{
|
||||
constexpr auto span_2d = DistributedTensor::get_distributed_spans();
|
||||
auto packet = [&]() {
|
||||
auto tmp = make_static_distributed_tensor<ArgmaxPacket>(x.get_tile_distribution());
|
||||
|
||||
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
tmp.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
ArgmaxPacket t;
|
||||
t.arg = x_tmp(i_j_idx); // !!! we reference x here
|
||||
t.value = tile_idx.at(number<1>{});
|
||||
tmp(i_j_idx) = t;
|
||||
});
|
||||
});
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
auto argmax_init = ArgmaxPacket{-numeric<DataType>::infinity(), 0};
|
||||
auto r = block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
|
||||
block_tile_reduce_xor_sync(r, f_argmax);
|
||||
|
||||
auto o = make_static_distributed_tensor<DataType>(dst_dist);
|
||||
auto i = make_static_distributed_tensor<IndexType>(dst_dist);
|
||||
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
ArgmaxPacket tmp = r(i_j_idx);
|
||||
o(i_j_idx) = tmp.arg;
|
||||
i(i_j_idx) = tmp.value;
|
||||
});
|
||||
});
|
||||
|
||||
// update value
|
||||
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
x.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
auto col_id = tile_idx.at(number<1>{});
|
||||
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
x_tmp(i_j_idx) = (col_id == r(i_j_idx).value) ? -numeric<DataType>::infinity()
|
||||
: x_tmp(i_j_idx);
|
||||
});
|
||||
});
|
||||
|
||||
if(threadIdx.x % Problem::ColLanes == 0)
|
||||
{
|
||||
store_tile(out_window_tmp, o);
|
||||
store_tile(idx_window_tmp, i);
|
||||
}
|
||||
move_tile_window(out_window_tmp, {number<0>{}, number<1>{}});
|
||||
move_tile_window(idx_window_tmp, {number<0>{}, number<1>{}});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,22 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
simple 2d topk implementation, along row (dim=1)
|
||||
requirement:
|
||||
1). each row is within a warp
|
||||
*/
|
||||
// Compile-time problem description consumed by BlockTopkStream2D.
template <typename DataType_, typename IndexType_, index_t ColLanes_>
|
||||
struct BlockTopkStream2DProblem
|
||||
{
|
||||
// type of the values being ranked (qualifiers removed)
using DataType = remove_cvref_t<DataType_>;
|
||||
// type of the emitted column indices (qualifiers removed)
using IndexType = remove_cvref_t<IndexType_>;
|
||||
// BlockTopkStream2D stores results only from threads whose threadIdx.x is a multiple of this
static constexpr index_t ColLanes = ColLanes_;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
10
include/ck_tile/ops/topk_softmax.hpp
Normal file
10
include/ck_tile/ops/topk_softmax.hpp
Normal file
@@ -0,0 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
166
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
Normal file
166
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
Normal file
@@ -0,0 +1,166 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Host-side argument bundle for the topk-softmax kernel.
// TopkSoftmaxKernel::MakeKargs copies every field one-to-one into the kernel arguments.
struct TopkSoftmaxHostArgs
|
||||
{
|
||||
const void* p_input; // input buffer; the kernel reads it as InputType
|
||||
void* p_output; // output weights buffer; the kernel writes it as WeightType
|
||||
void* p_indices; // output indices buffer; the kernel writes it as IndexType
|
||||
index_t num_rows; // number of input rows
|
||||
index_t num_experts; // number of columns per input row
|
||||
index_t topk; // number of columns per output/indices row
|
||||
index_t stride_input; // row stride for input, at least experts
|
||||
index_t stride_output; // row stride for output/indices, at least topk
|
||||
};
|
||||
|
||||
template <typename Pipeline_>
|
||||
struct TopkSoftmaxKernel
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = remove_cvref_t<typename Pipeline::Problem>;
|
||||
|
||||
using InputType = typename Problem::InputType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using IndexType = typename Problem::IndexType;
|
||||
|
||||
struct TopkSoftmaxKargs
|
||||
{
|
||||
const void* p_input;
|
||||
void* p_output;
|
||||
void* p_indices;
|
||||
index_t num_rows;
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t stride_input; // row stride for input, at least experts
|
||||
index_t stride_output; // row stride for output/indices, at least tpok
|
||||
};
|
||||
|
||||
using Kargs = TopkSoftmaxKargs;
|
||||
using Hargs = TopkSoftmaxHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
if constexpr(Problem::LaunchType > 0)
|
||||
{
|
||||
int num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return dim3(num_cu * Problem::LaunchType);
|
||||
}
|
||||
else
|
||||
{
|
||||
const int num_warps = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
|
||||
const int num_blocks =
|
||||
(num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
|
||||
return dim3(num_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_input = h.p_input;
|
||||
k.p_output = h.p_output;
|
||||
k.p_indices = h.p_indices;
|
||||
k.num_rows = h.num_rows;
|
||||
k.num_experts = h.num_experts;
|
||||
k.topk = h.topk;
|
||||
k.stride_input = h.stride_input;
|
||||
k.stride_output = h.stride_output;
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);
|
||||
|
||||
if(block_row_id > kargs.num_rows)
|
||||
return;
|
||||
|
||||
index_t block_os_inp = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_input);
|
||||
index_t block_os_out = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_output);
|
||||
index_t num_rows_rem = __builtin_amdgcn_readfirstlane(kargs.num_rows - block_row_id);
|
||||
|
||||
const auto input_window = [&]() {
|
||||
const InputType* p_input =
|
||||
reinterpret_cast<const InputType*>(kargs.p_input) + block_os_inp;
|
||||
|
||||
auto tmp = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_input,
|
||||
make_tuple(num_rows_rem, kargs.num_experts),
|
||||
make_tuple(kargs.stride_input, 1),
|
||||
number<Problem::VectorSize>{},
|
||||
number<1>{});
|
||||
|
||||
auto view = pad_tensor_view(
|
||||
tmp,
|
||||
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
|
||||
sequence<0, 1>{}); // out-most dim no need pad(leverage oob)
|
||||
|
||||
return make_tile_window(
|
||||
view,
|
||||
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
|
||||
{0, 0});
|
||||
}();
|
||||
|
||||
auto output_window = [&]() {
|
||||
WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) + block_os_out;
|
||||
auto tmp = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_output,
|
||||
make_tuple(num_rows_rem, kargs.topk),
|
||||
make_tuple(kargs.stride_output, 1),
|
||||
number<Problem::VectorSize>{},
|
||||
number<1>{});
|
||||
auto view =
|
||||
pad_tensor_view(tmp,
|
||||
make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
|
||||
sequence<0, 0>{}); // 1. out-most dim no need pad(leverage oob)
|
||||
// 2. we loop over topk 1-1, no need padding
|
||||
return make_tile_window(
|
||||
view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
|
||||
}();
|
||||
|
||||
auto indices_window = [&]() {
|
||||
IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) + block_os_out;
|
||||
auto tmp = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_indices,
|
||||
make_tuple(num_rows_rem, kargs.topk),
|
||||
make_tuple(kargs.stride_output, 1),
|
||||
number<Problem::VectorSize>{},
|
||||
number<1>{});
|
||||
auto view =
|
||||
pad_tensor_view(tmp,
|
||||
make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
|
||||
sequence<0, 0>{}); // 1. out-most dim no need pad(leverage oob)
|
||||
// 2. we loop over topk 1-1, no need padding
|
||||
return make_tile_window(
|
||||
view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
|
||||
}();
|
||||
|
||||
Pipeline{}(input_window,
|
||||
output_window,
|
||||
indices_window,
|
||||
kargs.num_rows,
|
||||
kargs.num_experts,
|
||||
kargs.topk,
|
||||
block_row_id);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,123 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
|
||||
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
|
||||
struct TopkSoftmaxWarpPerRowPipeline
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
|
||||
template <typename InputWindow, typename OutputWindow, typename IndexWindow>
|
||||
CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
|
||||
OutputWindow& out_window,
|
||||
IndexWindow& idx_window,
|
||||
index_t rows,
|
||||
index_t experts,
|
||||
index_t k,
|
||||
index_t block_row_id)
|
||||
{
|
||||
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
|
||||
auto inp_win = make_tile_window_linear_raw(
|
||||
input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
|
||||
#else
|
||||
auto inp_win = make_tile_window_linear(
|
||||
input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
|
||||
#endif
|
||||
auto out_win = make_tile_window_linear(out_window.get_bottom_tensor_view(),
|
||||
out_window.get_window_lengths(),
|
||||
out_window.get_window_origin(),
|
||||
Policy::template MakeOutputDistribution<Problem>());
|
||||
auto idx_win = make_tile_window_linear(idx_window.get_bottom_tensor_view(),
|
||||
idx_window.get_window_lengths(),
|
||||
idx_window.get_window_origin(),
|
||||
Policy::template MakeOutputDistribution<Problem>());
|
||||
|
||||
auto softmax = Policy::template GetSoftmax<Problem>();
|
||||
auto topk = Policy::template GetTopk<Problem>();
|
||||
|
||||
const index_t grid_rows_per_loop = gridDim.x * Problem::RowsPerBlock;
|
||||
|
||||
while(1)
|
||||
{
|
||||
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto x =
|
||||
load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});
|
||||
buffer_load_fence(number<0>{});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#else
|
||||
auto x = load_tile(inp_win);
|
||||
#endif
|
||||
// cast and pad input data
|
||||
auto w = [&]() {
|
||||
#if 0
|
||||
auto w_ = cast_tile<WeightType>(x);
|
||||
|
||||
constexpr auto span_2d = decltype(w_)::get_distributed_spans();
|
||||
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
w_.get_tile_distribution(), i_j_idx);
|
||||
const auto current_expert = x_indices.at(number<1>{});
|
||||
// set to -INF if OOB so that later softmax can work properly
|
||||
w_(i_j_idx) = current_expert >= experts ? -numeric<WeightType>::infinity()
|
||||
: w_(i_j_idx);
|
||||
});
|
||||
});
|
||||
return w_;
|
||||
#else
|
||||
auto w_ = make_static_distributed_tensor<WeightType>(x.get_tile_distribution());
|
||||
auto w_f = [&](auto idx) {
|
||||
w_(idx) = type_convert<WeightType>(x(idx));
|
||||
const auto x_indices =
|
||||
get_x_indices_from_distributed_indices(w_.get_tile_distribution(), idx);
|
||||
const auto current_expert = x_indices.at(number<1>{});
|
||||
w_(idx) =
|
||||
current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
|
||||
};
|
||||
tile_sweeper ts{w_, w_f};
|
||||
ts();
|
||||
return w_;
|
||||
#endif
|
||||
}();
|
||||
|
||||
// softmax
|
||||
auto y = softmax(w);
|
||||
|
||||
topk(y, out_win, idx_win, k);
|
||||
|
||||
// check exit
|
||||
if constexpr(Problem::LaunchType == 0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
block_row_id += grid_rows_per_loop;
|
||||
if(block_row_id >= rows)
|
||||
break;
|
||||
}
|
||||
|
||||
move_tile_window(inp_win, {grid_rows_per_loop, number<0>{}});
|
||||
move_tile_window(out_win, {grid_rows_per_loop, number<0>{}});
|
||||
move_tile_window(idx_win, {grid_rows_per_loop, number<0>{}});
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/softmax.hpp"
|
||||
#include "ck_tile/ops/topk.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct TopkSoftmaxWarpPerRowPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
|
||||
{
|
||||
// TODO: Y dim must have one dim that is not reduced
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<Problem::IssuesPerCol,
|
||||
Problem::WarpsPerBlock,
|
||||
Problem::RowsPerWarpPerColIssue>,
|
||||
sequence<Problem::IssuesPerRow, Problem::LanesPerRow, Problem::VectorSize>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Problem::LanesPerRow>, // repeat this one
|
||||
tuple<sequence<Problem::IssuesPerCol,
|
||||
Problem::WarpsPerBlock,
|
||||
Problem::RowsPerWarpPerColIssue>,
|
||||
sequence<1>>, // each row write out single element
|
||||
tuple<sequence<1>, sequence<1, 0>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSoftmax()
|
||||
{
|
||||
using softmax_problem = BlockSoftmax2DProblem<typename Problem::WeightType>;
|
||||
return BlockSoftmax2D<softmax_problem>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetTopk()
|
||||
{
|
||||
using topk_problem = BlockTopkStream2DProblem<typename Problem::WeightType,
|
||||
typename Problem::IndexType,
|
||||
Problem::LanesPerRow>;
|
||||
// Note: replicate is LanesPerRow
|
||||
return BlockTopkStream2D<topk_problem>{};
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,46 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename InputType_,
|
||||
typename WeightType_,
|
||||
typename IndexType_,
|
||||
index_t Experts_,
|
||||
index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
|
||||
index_t BytesPerIssue_ = sizeof(InputType_),
|
||||
index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy
|
||||
index_t BlockSize_ = 256>
|
||||
struct TopkSoftmaxWarpPerRowProblem
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using InputType = remove_cvref_t<InputType_>;
|
||||
using WeightType = remove_cvref_t<WeightType_>;
|
||||
using IndexType = remove_cvref_t<IndexType_>;
|
||||
|
||||
static constexpr index_t LaunchType = LaunchType_;
|
||||
static constexpr index_t Experts = Experts_;
|
||||
static constexpr index_t BytesPerIssue = BytesPerIssue_;
|
||||
static constexpr index_t IssuesPerCol = IssuesPerCol_;
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
static constexpr index_t WarpSize = get_warp_size();
|
||||
|
||||
static_assert(BytesPerIssue % sizeof(InputType) == 0);
|
||||
static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);
|
||||
static_assert(Experts % VectorSize == 0);
|
||||
static constexpr index_t LanesPerRow = min(Experts / VectorSize, WarpSize);
|
||||
static_assert(WarpSize % LanesPerRow == 0);
|
||||
static constexpr index_t RowsPerWarpPerColIssue = WarpSize / LanesPerRow;
|
||||
static constexpr index_t RowsPerWarp = IssuesPerCol * RowsPerWarpPerColIssue;
|
||||
static constexpr index_t IssuesPerRow = Experts / (LanesPerRow * VectorSize);
|
||||
|
||||
static constexpr index_t WarpsPerBlock = BlockSize / WarpSize;
|
||||
static constexpr index_t RowsPerBlock = RowsPerWarp * WarpsPerBlock;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user