mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Merge branch 'develop' into hot_fix_address_space
This commit is contained in:
12
.github/CODEOWNERS
vendored
12
.github/CODEOWNERS
vendored
@@ -1,8 +1,8 @@
|
||||
* @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj
|
||||
* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz
|
||||
# Documentation files
|
||||
docs/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj
|
||||
*.md @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj
|
||||
*.rst @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj
|
||||
.readthedocs.yaml @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj
|
||||
docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz
|
||||
*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz
|
||||
*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz
|
||||
.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz
|
||||
# Header directory for Doxygen documentation
|
||||
library/include/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj
|
||||
library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz
|
||||
|
||||
23
CHANGELOG.md
23
CHANGELOG.md
@@ -2,6 +2,29 @@
|
||||
|
||||
Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/).
|
||||
|
||||
## Composable Kernel 1.1.0 for ROCm 6.5.0
|
||||
|
||||
### Additions
|
||||
|
||||
None
|
||||
|
||||
### Optimizations
|
||||
|
||||
None
|
||||
|
||||
### Fixes
|
||||
|
||||
None
|
||||
|
||||
### Changes
|
||||
|
||||
* Removed support for gfx940 and gfx941 targets (#1944)
|
||||
* Replaced the raw buffer load/store intrinsics with Clang20 built-ins (#1876)
|
||||
|
||||
### Known issues
|
||||
|
||||
None
|
||||
|
||||
## Composable Kernel 1.1.0 for ROCm 6.1.0
|
||||
|
||||
### Additions
|
||||
|
||||
6
Jenkinsfile
vendored
6
Jenkinsfile
vendored
@@ -603,6 +603,10 @@ def Build_CK(Map conf=[:]){
|
||||
"""
|
||||
}
|
||||
}
|
||||
// set ownership of all files and folders to jenkins after all steps completed
|
||||
dir("build"){
|
||||
sh "sudo chown -R jenkins:jenkins ../*"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -858,7 +862,7 @@ pipeline {
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \
|
||||
/cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \
|
||||
-D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 -D DL_KERNELS \
|
||||
-D __gfx908__ -D __gfx90a__ -D __gfx940__ -D __gfx941__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \
|
||||
-D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \
|
||||
-U __gfx803__ -U __gfx900__ -U __gfx906__ -U CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 \
|
||||
--file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log"
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
|
||||
|
||||
You must set the `GPU_TARGETS` macro to specify the GPU target architecture(s) you want
|
||||
to run CK on. You can specify single or multiple architectures. If you specify multiple architectures,
|
||||
use a semicolon between each; for example, `gfx908;gfx90a;gfx940`.
|
||||
use a semicolon between each; for example, `gfx908;gfx90a;gfx942`.
|
||||
|
||||
```bash
|
||||
cmake \
|
||||
|
||||
@@ -13,7 +13,7 @@ std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
|
||||
|
||||
const std::unordered_set<std::string>& get_xdlop_archs()
|
||||
{
|
||||
static std::unordered_set<std::string> supported_archs{"gfx90a", "gfx908", "gfx940", "gfx942"};
|
||||
static std::unordered_set<std::string> supported_archs{"gfx90a", "gfx908", "gfx942"};
|
||||
return supported_archs;
|
||||
}
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
|
||||
|
||||
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
|
||||
|
||||
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -16,7 +16,7 @@ if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942 gfx950)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
|
||||
@@ -109,9 +109,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
if(FILE_NAME MATCHES "_xdl")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
@@ -204,7 +204,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
if(FILE_NAME MATCHES "_xdl")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
endif()
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
|
||||
@@ -126,6 +126,6 @@ Note FA use bottom-right by default to express swa case, here we require you exp
|
||||
TBD
|
||||
|
||||
## FP8 experimental support
|
||||
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx940/941/942 machine and ROCm 6.0+.
|
||||
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+.
|
||||
|
||||
Currently we only support `-vlayout=c`( `hdim*seqlen` for V matrix) and `-squant=1`(static quantization) with `hdim=128` for fp8 now. Full feature support will come later.
|
||||
|
||||
@@ -55,11 +55,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
#endif
|
||||
|
||||
// define general macros for various architectures
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__) || defined(__gfx950__)
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
|
||||
#define __gfx9__
|
||||
#endif
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
|
||||
|
||||
@@ -55,22 +55,19 @@ inline std::string get_device_name()
|
||||
inline bool is_xdl_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
|
||||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
|
||||
}
|
||||
|
||||
inline bool is_lds_direct_load_supported()
|
||||
{
|
||||
// Check if direct loads from global memory to LDS are supported.
|
||||
return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
|
||||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942" ||
|
||||
return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx942" ||
|
||||
ck::get_device_name() == "gfx950";
|
||||
}
|
||||
|
||||
inline bool is_bf16_atomic_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
|
||||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
|
||||
return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
|
||||
}
|
||||
|
||||
inline bool is_gfx101_supported()
|
||||
|
||||
@@ -12,7 +12,11 @@
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#if __clang_major__ >= 20
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp"
|
||||
#else
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp"
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -0,0 +1,322 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
|
||||
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseSparseEmbedding,
|
||||
typename EmbType,
|
||||
typename IndexType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename OutType,
|
||||
typename OutGridDesc,
|
||||
typename EmbElementwiseOperation,
|
||||
ck::index_t NumEmbeddings>
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
__global__ void kernel_sparse_embeddings_forward_layernorm(
|
||||
OutType* p_out,
|
||||
const ck::Array<EmbType*, NumEmbeddings> p_embs,
|
||||
const ck::Array<IndexType*, NumEmbeddings> p_indexes,
|
||||
const GammaDataType* p_gamma,
|
||||
const BetaDataType* p_beta,
|
||||
const OutGridDesc out_grid_desc,
|
||||
const AccDataType epsilon,
|
||||
const EmbElementwiseOperation emb_elementwise_op)
|
||||
{
|
||||
GridwiseSparseEmbedding::Run(
|
||||
p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, emb_elementwise_op);
|
||||
}
|
||||
|
||||
template <typename EmbType,
|
||||
typename IndexType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename AccDataType,
|
||||
typename OutType,
|
||||
typename OutGridDesc,
|
||||
typename EmbElementwiseOperation,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t DimClusterSize,
|
||||
ck::index_t RowClusterSize,
|
||||
ck::index_t DimPerBlock, // Row x Dim, along Dim
|
||||
ck::index_t RowPerBlock, // Row x Dim, along Row
|
||||
ck::index_t DimThreadSize, // this is actually not vector, but number of registers
|
||||
ck::index_t RowVectorSize,
|
||||
ck::index_t NumEmbeddings>
|
||||
struct GridwiseSparseEmbeddingsForwardLayernorm
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr index_t WaveSize = 64;
|
||||
|
||||
static_assert(BlockSize == RowClusterSize * DimClusterSize,
|
||||
"Invalid cluster distribution within block");
|
||||
static_assert(RowClusterSize % WaveSize == 0, "need to be wavewise");
|
||||
|
||||
static_assert(DimPerBlock % (DimClusterSize * DimThreadSize) == 0, "");
|
||||
static_assert(RowPerBlock % (RowClusterSize * RowVectorSize) == 0, "");
|
||||
|
||||
static constexpr auto DimSubBlocks = DimPerBlock / (DimClusterSize * DimThreadSize);
|
||||
static constexpr auto RowSubBlocks = RowPerBlock / (RowClusterSize * RowVectorSize);
|
||||
|
||||
static_assert((DimPerBlock % DimSubBlocks == 0) && (RowPerBlock % RowSubBlocks == 0), "");
|
||||
static constexpr auto DimPerSubBlock = DimPerBlock / DimSubBlocks;
|
||||
static constexpr auto RowPerSubBlock = RowPerBlock / RowSubBlocks;
|
||||
|
||||
using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<DimSubBlocks * DimThreadSize>{}, Number<RowSubBlocks * RowVectorSize>{})));
|
||||
|
||||
using ThreadwiseWolfordDescReduce = decltype(make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<DimSubBlocks * DimThreadSize>{})));
|
||||
|
||||
using ThreadwiseWelford =
|
||||
ThreadwiseWelford<AccDataType, ThreadwiseWolfordDesc2D, ThreadwiseWolfordDescReduce>;
|
||||
|
||||
using ThreadClusterLength = Sequence<DimClusterSize, RowClusterSize>;
|
||||
|
||||
using BlockwiseWelford =
|
||||
BlockwiseWelford<AccDataType, BlockSize, ThreadClusterLength, Sequence<0, 1>>;
|
||||
|
||||
__device__ static void Run(OutType* p_out,
|
||||
const ck::Array<EmbType*, NumEmbeddings> p_embs,
|
||||
const ck::Array<IndexType*, NumEmbeddings> p_indexes,
|
||||
const GammaDataType* p_gamma,
|
||||
const BetaDataType* p_beta,
|
||||
const OutGridDesc,
|
||||
const AccDataType epsilon,
|
||||
const EmbElementwiseOperation emb_elementwise_op)
|
||||
{
|
||||
const index_t thread_local_id = get_thread_local_1d_id();
|
||||
const index_t block_global_id = get_block_1d_id();
|
||||
|
||||
constexpr auto thread_cluster_desc =
|
||||
make_cluster_descriptor(Sequence<DimClusterSize, RowClusterSize>{}, Sequence<0, 1>{});
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
|
||||
|
||||
const auto thread_dim_cluster_id = thread_cluster_idx[I0];
|
||||
const auto thread_row_cluster_id = thread_cluster_idx[I1];
|
||||
|
||||
const auto wave_dim_id = __builtin_amdgcn_readfirstlane(thread_dim_cluster_id / WaveSize);
|
||||
|
||||
const auto index_start = block_global_id * DimPerBlock + wave_dim_id * DimThreadSize;
|
||||
|
||||
auto threadwise_welford = ThreadwiseWelford();
|
||||
threadwise_welford.max_count_ = RowSubBlocks * RowVectorSize;
|
||||
|
||||
constexpr auto thread_buf_size =
|
||||
DimSubBlocks * DimThreadSize * RowSubBlocks * RowVectorSize;
|
||||
constexpr auto thread_buf_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(DimSubBlocks, DimThreadSize, RowSubBlocks, RowVectorSize));
|
||||
constexpr auto mean_var_buf_size = DimSubBlocks * DimThreadSize;
|
||||
constexpr auto mean_var_buf_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(DimSubBlocks, DimThreadSize));
|
||||
constexpr auto gamma_beta_buf_size = RowSubBlocks * RowVectorSize;
|
||||
constexpr auto gamma_beta_buf_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize));
|
||||
|
||||
ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true>,
|
||||
NumEmbeddings>
|
||||
in_thread_bufs;
|
||||
ck::Array<StaticBuffer<AddressSpaceEnum::Vgpr, IndexType, DimPerBlock, true>, NumEmbeddings>
|
||||
index_bufs;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, thread_buf_size, true> acc_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, gamma_beta_buf_size, true>
|
||||
gamma_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, gamma_beta_buf_size, true>
|
||||
beta_thread_buf;
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> mean_thread_buf;
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, mean_var_buf_size, true> var_thread_buf;
|
||||
|
||||
auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
|
||||
ck::Array<vector_type_maker_t<EmbType, RowVectorSize>, NumEmbeddings> emb_vectors;
|
||||
auto emb_a = emb_vectors[0];
|
||||
using src_vector_t = typename decltype(emb_a)::type;
|
||||
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
|
||||
constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_;
|
||||
|
||||
auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
|
||||
sizeof(EmbType) * RowVectorSize;
|
||||
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
|
||||
IndexType index = index_bufs[i_embedding_][Number<current_dim>{}];
|
||||
|
||||
__amdgpu_buffer_rsrc_t emb_res =
|
||||
make_wave_buffer_resource_with_default_range_new(p_embs[i_embedding_] +
|
||||
index * RowPerBlock);
|
||||
emb_vectors(i_embedding_).template AsType<src_vector_t>()(I0) =
|
||||
amd_buffer_load_impl<EmbType, RowVectorSize>(emb_res, thread_offset, 0);
|
||||
});
|
||||
|
||||
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
|
||||
in_thread_bufs(i_embedding_)(Number<register_offset>{}) =
|
||||
ck::type_convert<AccDataType>(
|
||||
emb_vectors[i_embedding_].template AsType<EmbType>()[i_row_vec_]);
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
|
||||
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
|
||||
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
auto in_data_refs = generate_tie(
|
||||
[&](auto i_embedding_) -> const auto& {
|
||||
return in_thread_bufs(i_embedding_)(Number<register_offset>{});
|
||||
},
|
||||
Number<NumEmbeddings>{});
|
||||
auto out_data_refs = generate_tie(
|
||||
[&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
|
||||
Number<1>{});
|
||||
unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
|
||||
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
|
||||
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
constexpr auto mean_var_offset =
|
||||
mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
|
||||
|
||||
threadwise_welford.cur_count_++;
|
||||
threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
|
||||
var_thread_buf(Number<mean_var_offset>{}),
|
||||
acc_thread_buf(Number<register_offset>{}));
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
auto threadwise_normalize_store_out = [&](auto i_dim_sub_, auto i_row_sub_) {
|
||||
__amdgpu_buffer_rsrc_t out_res =
|
||||
make_wave_buffer_resource_with_default_range_new(p_out + index_start * RowPerBlock);
|
||||
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
|
||||
vector_type_maker_t<OutType, RowVectorSize> out_vector;
|
||||
using dst_vector_t = typename decltype(out_vector)::type;
|
||||
|
||||
constexpr auto mean_var_offset =
|
||||
mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
|
||||
auto divisor =
|
||||
1 / __builtin_amdgcn_sqrtf(var_thread_buf(Number<mean_var_offset>{}) + epsilon);
|
||||
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
constexpr auto gamma_beta_offset =
|
||||
gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
|
||||
|
||||
auto acc_val = acc_thread_buf[Number<register_offset>{}];
|
||||
acc_val = (acc_val - mean_thread_buf(Number<mean_var_offset>{})) * divisor;
|
||||
acc_val = acc_val * gamma_thread_buf[Number<gamma_beta_offset>{}] +
|
||||
beta_thread_buf[Number<gamma_beta_offset>{}];
|
||||
|
||||
out_vector.template AsType<OutType>()(Number<i_row_vec_>{}) =
|
||||
type_convert<OutType>(acc_val);
|
||||
});
|
||||
|
||||
index_t thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
|
||||
sizeof(OutType) * RowVectorSize;
|
||||
|
||||
amd_buffer_store_impl<OutType, RowVectorSize>(
|
||||
out_vector.template AsType<dst_vector_t>()[Number<0>{}],
|
||||
out_res,
|
||||
thread_offset,
|
||||
0);
|
||||
});
|
||||
};
|
||||
|
||||
// first load index
|
||||
ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
|
||||
// prefer use s_load
|
||||
ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
|
||||
index_bufs(i_embedding_)(i_idx_) =
|
||||
p_indexes[i_embedding_][index_start + i_idx_.value];
|
||||
});
|
||||
});
|
||||
|
||||
// load gamma/beta
|
||||
static_for<0, RowSubBlocks, 1>{}([&](auto i_row_sub_) {
|
||||
vector_type_maker_t<GammaDataType, RowVectorSize> gamma_vector;
|
||||
vector_type_maker_t<BetaDataType, RowVectorSize> beta_vector;
|
||||
|
||||
index_t thread_offset_gamma = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
|
||||
sizeof(GammaDataType) * RowVectorSize;
|
||||
index_t thread_offset_beta = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) *
|
||||
sizeof(BetaDataType) * RowVectorSize;
|
||||
|
||||
__amdgpu_buffer_rsrc_t gamma_res =
|
||||
make_wave_buffer_resource_with_default_range_new(p_gamma);
|
||||
__amdgpu_buffer_rsrc_t beta_res =
|
||||
make_wave_buffer_resource_with_default_range_new(p_beta);
|
||||
|
||||
gamma_vector.template AsType<typename decltype(gamma_vector)::type>()(I0) =
|
||||
amd_buffer_load_impl<GammaDataType, RowVectorSize>(
|
||||
gamma_res, thread_offset_gamma, 0);
|
||||
beta_vector.template AsType<typename decltype(beta_vector)::type>()(I0) =
|
||||
amd_buffer_load_impl<BetaDataType, RowVectorSize>(beta_res, thread_offset_beta, 0);
|
||||
|
||||
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
|
||||
constexpr auto offset =
|
||||
gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_));
|
||||
gamma_thread_buf(Number<offset>{}) = type_convert<AccDataType>(
|
||||
gamma_vector.template AsType<GammaDataType>()[Number<i_row_vec_>{}]);
|
||||
beta_thread_buf(Number<offset>{}) = type_convert<AccDataType>(
|
||||
beta_vector.template AsType<BetaDataType>()[Number<i_row_vec_>{}]);
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, thread_buf_size, 1>{}(
|
||||
[&](auto I) { acc_thread_buf(I) = type_convert<AccDataType>(0.0f); });
|
||||
|
||||
static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
|
||||
mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
|
||||
var_thread_buf(I) = type_convert<AccDataType>(0.0f);
|
||||
});
|
||||
|
||||
static_for<0, DimSubBlocks, 1>{}([&](auto i_dim_sub) {
|
||||
load_current_sub_row(i_dim_sub, Number<0>{});
|
||||
static_for<0, RowSubBlocks - 1, 1>{}([&](auto i_row) {
|
||||
load_current_sub_row(i_dim_sub, Number<1>{} + i_row);
|
||||
accumulate_current_sub_row(i_dim_sub, i_row);
|
||||
threadwise_welford_sub_row(i_dim_sub, i_row);
|
||||
});
|
||||
accumulate_current_sub_row(i_dim_sub, Number<RowSubBlocks - 1>{});
|
||||
threadwise_welford_sub_row(i_dim_sub, Number<RowSubBlocks - 1>{});
|
||||
|
||||
// blockwise welford
|
||||
static_for<0, mean_var_buf_size, 1>{}([&](auto I) {
|
||||
if constexpr(I > 0)
|
||||
block_sync_lds();
|
||||
BlockwiseWelford::Run(
|
||||
mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_);
|
||||
});
|
||||
|
||||
// store
|
||||
static_for<0, RowSubBlocks, 1>{}(
|
||||
[&](auto i_row) { threadwise_normalize_store_out(i_dim_sub, i_row); });
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
886
include/ck/utility/amd_buffer_addressing_builtins.hpp
Normal file
886
include/ck/utility/amd_buffer_addressing_builtins.hpp
Normal file
@@ -0,0 +1,886 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename T>
|
||||
union BufferResource
|
||||
{
|
||||
__device__ constexpr BufferResource() : content{} {}
|
||||
|
||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
||||
int32x4_t content;
|
||||
StaticallyIndexedArray<T*, 2> address;
|
||||
StaticallyIndexedArray<int32_t, 4> range;
|
||||
StaticallyIndexedArray<int32_t, 4> config;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
|
||||
{
|
||||
BufferResource<T> wave_buffer_resource;
|
||||
|
||||
// wavewise base address (64 bit)
|
||||
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
|
||||
// wavewise range (32 bit)
|
||||
wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
|
||||
// wavewise setting (32 bit)
|
||||
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
|
||||
|
||||
return wave_buffer_resource.content;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave)
|
||||
{
|
||||
BufferResource<T> wave_buffer_resource;
|
||||
|
||||
// wavewise base address (64 bit)
|
||||
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
|
||||
// wavewise range (32 bit)
|
||||
wave_buffer_resource.range(Number<2>{}) = 0xffffffff; // max possible range
|
||||
// wavewise setting (32 bit)
|
||||
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
|
||||
|
||||
return wave_buffer_resource.content;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_new(T* p_wave,
|
||||
index_t element_space_size)
|
||||
{
|
||||
// wavewise base address (64 bit)
|
||||
auto p = const_cast<remove_cv_t<T>*>(p_wave);
|
||||
int32_t stride = 0;
|
||||
int32_t num = element_space_size * sizeof(T);
|
||||
auto flags = CK_BUFFER_RESOURCE_3RD_DWORD;
|
||||
|
||||
return __builtin_amdgcn_make_buffer_rsrc(p, stride, num, flags);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_with_default_range_new(T* p_wave)
|
||||
{
|
||||
// wavewise base address (64 bit)
|
||||
auto p = const_cast<remove_cv_t<T>*>(p_wave);
|
||||
int32_t stride = 0;
|
||||
int32_t num = 0xffffffff;
|
||||
auto flags = CK_BUFFER_RESOURCE_3RD_DWORD;
|
||||
|
||||
return __builtin_amdgcn_make_buffer_rsrc(p, stride, num, flags);
|
||||
}
|
||||
|
||||
// buffer atomic-add fp16
|
||||
__device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
|
||||
half2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16.v4i32");
|
||||
|
||||
// buffer atomic-add i32
|
||||
__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
|
||||
int32_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32.v4i32");
|
||||
|
||||
// buffer atomic-add fp32
|
||||
__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
|
||||
float vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32.v4i32");
|
||||
|
||||
// buffer atomic-add fp32
|
||||
__device__ double llvm_amdgcn_raw_buffer_atomic_max_fp64(
|
||||
double vdata,
|
||||
int32x4_t rsrc, // dst_wave_buffer_resource
|
||||
int voffset, // dst_thread_addr_offset
|
||||
int soffset, // dst_wave_addr_offset
|
||||
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64.v4i32");
|
||||
|
||||
// memory coherency bit for buffer store/load instruction
|
||||
// check ISA manual for each GFX target
|
||||
// e.g. for
|
||||
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf,
|
||||
// page 67~68
|
||||
enum struct AmdBufferCoherenceEnum
|
||||
{
|
||||
DefaultCoherence = 0, // default value
|
||||
GLC = 1,
|
||||
SLC = 2,
|
||||
GLC_SLC = 3,
|
||||
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
|
||||
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
|
||||
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
|
||||
WAVE_NT0 = 0,
|
||||
WAVE_NT1 = 2,
|
||||
GROUP_NT0 = 1,
|
||||
GROUP_NT1 = 3,
|
||||
DEVICE_NT0 = 8,
|
||||
DEVICE_NT1 = 10,
|
||||
SYSTEM_NT0 = 9,
|
||||
SYSTEM_NT1 = 11,
|
||||
};
|
||||
|
||||
template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
__device__ typename vector_type<int8_t, N>::type
|
||||
amd_buffer_load_impl_raw(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset)
|
||||
{
|
||||
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return __builtin_amdgcn_raw_buffer_load_b8(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
|
||||
int16_t tmp = __builtin_amdgcn_raw_buffer_load_b16(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<int8x2_t>(tmp);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
int32_t tmp = __builtin_amdgcn_raw_buffer_load_b32(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<int8x4_t>(tmp);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
int32x2_t tmp = __builtin_amdgcn_raw_buffer_load_b64(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<int8x8_t>(tmp);
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
int32x4_t tmp = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
return bit_cast<int8x16_t>(tmp);
|
||||
}
|
||||
else if constexpr(N == 32)
|
||||
{
|
||||
int32x4_t tmp0 = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
int32x4_t tmp1 =
|
||||
__builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int32_t),
|
||||
static_cast<index_t>(coherence));
|
||||
vector_type<int32_t, 8> tmp;
|
||||
|
||||
tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
|
||||
tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;
|
||||
|
||||
return bit_cast<int8x32_t>(tmp);
|
||||
}
|
||||
else if constexpr(N == 64)
|
||||
{
|
||||
int32x4_t tmp0 = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
int32x4_t tmp1 =
|
||||
__builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int32_t),
|
||||
static_cast<index_t>(coherence));
|
||||
int32x4_t tmp2 =
|
||||
__builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 8 * sizeof(int32_t),
|
||||
static_cast<index_t>(coherence));
|
||||
int32x4_t tmp3 =
|
||||
__builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 12 * sizeof(int32_t),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
vector_type<int32_t, 16> tmp;
|
||||
|
||||
tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
|
||||
tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;
|
||||
tmp.AsType<int32x4_t>()(Number<2>{}) = tmp2;
|
||||
tmp.AsType<int32x4_t>()(Number<3>{}) = tmp3;
|
||||
|
||||
return bit_cast<int8x64_t>(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
index_t N,
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
__device__ typename vector_type<T, N>::type
|
||||
amd_buffer_load_impl(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset)
|
||||
{
|
||||
static_assert(
|
||||
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
"wrong! not implemented");
|
||||
|
||||
using r_t = typename vector_type<T, N>::type;
|
||||
auto raw_data = amd_buffer_load_impl_raw<sizeof(T) * N, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset);
|
||||
return bit_cast<r_t>(raw_data);
|
||||
}
|
||||
|
||||
template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
__device__ void
|
||||
amd_buffer_store_impl_raw(const typename vector_type<int8_t, N>::type src_thread_data,
|
||||
__amdgpu_buffer_rsrc_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
{
|
||||
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
__builtin_amdgcn_raw_buffer_store_b8(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
|
||||
__builtin_amdgcn_raw_buffer_store_b16(bit_cast<int16_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
__builtin_amdgcn_raw_buffer_store_b32(bit_cast<int32_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
__builtin_amdgcn_raw_buffer_store_b64(bit_cast<int32x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
__builtin_amdgcn_raw_buffer_store_b128(bit_cast<int32x4_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 32)
|
||||
{
|
||||
vector_type<int32_t, 8> tmp{bit_cast<int32x8_t>(src_thread_data)};
|
||||
|
||||
__builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
__builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(int32_t) * 4,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
else if constexpr(N == 64)
|
||||
{
|
||||
vector_type<int32_t, 16> tmp{bit_cast<int32x16_t>(src_thread_data)};
|
||||
|
||||
__builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
__builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(int32_t) * 4,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
__builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<2>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(int32_t) * 8,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
__builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<3>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(int32_t) * 12,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
index_t N,
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
|
||||
__amdgpu_buffer_rsrc_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
{
|
||||
static_assert(
|
||||
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, f8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, bf8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, fp8_storage_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
"wrong! not implemented");
|
||||
|
||||
using r_t = typename vector_type<int8_t, sizeof(T) * N>::type;
|
||||
|
||||
amd_buffer_store_impl_raw<sizeof(T) * N, coherence>(bit_cast<r_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset);
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
__device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
|
||||
T* addr)
|
||||
{
|
||||
static_assert((is_same<T, bhalf_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)),
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(is_same<T, half_t>::value)
|
||||
{
|
||||
vector_type<half_t, N> tmp{src_thread_data};
|
||||
static_for<0, N / 2, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i,
|
||||
tmp.template AsType<half2_t>()[i]);
|
||||
});
|
||||
}
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
else if constexpr(is_same<T, bhalf_t>::value)
|
||||
{
|
||||
vector_type<bhalf_t, N> tmp{src_thread_data};
|
||||
static_for<0, N / 2, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i,
|
||||
tmp.template AsType<bhalf2_t>()[i]);
|
||||
});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
__device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
{
|
||||
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
vector_type<float, 2> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(float),
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
vector_type<float, 4> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(float),
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<2>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 2 * sizeof(float),
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType<float>()[Number<3>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 3 * sizeof(float),
|
||||
0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half_t>::value)
|
||||
{
|
||||
if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
vector_type<half_t, 4> tmp{src_thread_data};
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType<half2_t>()[i],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + i * sizeof(half2_t),
|
||||
0);
|
||||
});
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
vector_type<half_t, 8> tmp{src_thread_data};
|
||||
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType<half2_t>()[i],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + i * sizeof(half2_t),
|
||||
0);
|
||||
});
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
vector_type<int32_t, 2> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(int32_t),
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
vector_type<int32_t, 4> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(int32_t),
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<2>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 2 * sizeof(int32_t),
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType<int32_t>()[Number<3>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 3 * sizeof(int32_t),
|
||||
0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
__device__ void amd_buffer_atomic_max_impl(const typename vector_type<T, N>::type src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
{
|
||||
static_assert((is_same<T, double>::value && (N == 1 || N == 2 || N == 4)),
|
||||
"wrong! not implemented");
|
||||
if constexpr(is_same<T, double>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
vector_type<double, 2> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(double),
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
vector_type<double, 4> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + sizeof(double),
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<2>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 2 * sizeof(double),
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType<double>()[Number<3>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 3 * sizeof(double),
|
||||
0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buffer_load requires:
|
||||
// 1) p_src_wave must point to global memory space
|
||||
// 2) p_src_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T,
|
||||
index_t N,
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
__device__ typename vector_type_maker<T, N>::type::type
|
||||
amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
bool src_thread_element_valid,
|
||||
index_t src_element_space_size)
|
||||
{
|
||||
const __amdgpu_buffer_rsrc_t src_wave_buffer_resource =
|
||||
make_wave_buffer_resource_new(p_src_wave, src_element_space_size);
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
|
||||
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
|
||||
#else
|
||||
|
||||
vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0)};
|
||||
return src_thread_element_valid ? tmp : vector_t(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
// buffer_load requires:
|
||||
// 1) p_src_wave must point to global memory space
|
||||
// 2) p_src_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T,
|
||||
index_t N,
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
__device__ typename vector_type_maker<T, N>::type::type
|
||||
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
bool src_thread_element_valid,
|
||||
index_t src_element_space_size,
|
||||
T customized_value)
|
||||
{
|
||||
const __amdgpu_buffer_rsrc_t src_wave_buffer_resource =
|
||||
make_wave_buffer_resource_new(p_src_wave, src_element_space_size);
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0)};
|
||||
|
||||
return src_thread_element_valid ? tmp : vector_t(customized_value);
|
||||
}
|
||||
|
||||
// buffer_store requires:
|
||||
// 1) p_dst_wave must point to global memory
|
||||
// 2) p_dst_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T,
|
||||
index_t N,
|
||||
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
|
||||
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size)
|
||||
{
|
||||
const __amdgpu_buffer_rsrc_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource_new(p_dst_wave, dst_element_space_size);
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#else
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// buffer_atomic_add requires:
|
||||
// 1) p_dst_wave must point to global memory
|
||||
// 2) p_dst_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ void
|
||||
amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size)
|
||||
{
|
||||
const int32x4_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
if constexpr(is_same<T, bhalf_t>::value)
|
||||
{
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
amd_global_atomic_add_impl<scalar_t, vector_size>(
|
||||
src_thread_data, p_dst_wave + dst_thread_element_offset);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
|
||||
|
||||
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#else
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
amd_buffer_atomic_add_impl<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// buffer_atomic_max requires:
|
||||
// 1) p_dst_wave must point to global memory
|
||||
// 2) p_dst_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ void
|
||||
amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size)
|
||||
{
|
||||
const int32x4_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
|
||||
|
||||
amd_buffer_atomic_max_impl<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#else
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
amd_buffer_atomic_max_impl<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Direct loads from global to LDS.
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr,
|
||||
index_t size,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t offset,
|
||||
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds.v4i32");
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
template <typename T, index_t NumElemsPerThread>
|
||||
__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
const index_t global_offset,
|
||||
T* lds_base_ptr,
|
||||
const index_t lds_offset,
|
||||
const bool is_valid,
|
||||
const index_t src_element_space_size)
|
||||
{
|
||||
// Direct loads require that each thread reads and writes exactly a single DWORD.
|
||||
constexpr auto dword_bytes = 4;
|
||||
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
|
||||
static_assert(bytes_per_thread == dword_bytes);
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
const uint32_t* global_ptr =
|
||||
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
|
||||
#else
|
||||
const uint32_t* global_ptr =
|
||||
reinterpret_cast<uint32_t*>(reinterpret_cast<size_t>(global_base_ptr));
|
||||
#endif
|
||||
const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
|
||||
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
|
||||
|
||||
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
|
||||
T* lds_ptr = lds_base_ptr + lds_offset;
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
auto const lds_ptr_sgpr =
|
||||
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
|
||||
#else
|
||||
auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<size_t>(lds_ptr)));
|
||||
#endif
|
||||
asm volatile("s_mov_b32 m0, %0; \n\t"
|
||||
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
|
||||
"v"(global_offset_bytes),
|
||||
"s"(src_resource)
|
||||
: "memory");
|
||||
#else
|
||||
// LDS pointer must be attributed with the LDS address space.
|
||||
__attribute__((address_space(3))) uint32_t* lds_ptr =
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
|
||||
#else
|
||||
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
|
||||
reinterpret_cast<size_t>(lds_base_ptr + lds_offset));
|
||||
#endif
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
@@ -21,8 +21,7 @@
|
||||
#define CK_USE_OCP_FP8 0
|
||||
#endif
|
||||
|
||||
#if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
|
||||
defined(__gfx1201__) || defined(__gfx950__)) && \
|
||||
#if(defined(__gfx942__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && \
|
||||
__HIP_DEVICE_COMPILE__
|
||||
#define CK_FP8_CVT_FAST_PATH 1
|
||||
#else
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
namespace ck {
|
||||
// Define the common macro for MI300 models
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
|
||||
|
||||
@@ -33,7 +33,11 @@
|
||||
#include "ck/utility/thread_group.hpp"
|
||||
#include "ck/utility/debug.hpp"
|
||||
|
||||
#include "ck/utility/amd_buffer_addressing.hpp"
|
||||
#if __clang_major__ >= 20
|
||||
#include "amd_buffer_addressing_builtins.hpp"
|
||||
#else
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
#endif
|
||||
#include "ck/utility/amd_wave_read_first_lane.hpp"
|
||||
#include "ck/utility/generic_memory_space_atomic.hpp"
|
||||
#include "ck/utility/get_id.hpp"
|
||||
|
||||
@@ -7,7 +7,11 @@
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "enable_if.hpp"
|
||||
#include "c_style_pointer_cast.hpp"
|
||||
#if __clang_major__ >= 20
|
||||
#include "amd_buffer_addressing_builtins.hpp"
|
||||
#else
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
#endif
|
||||
#include "generic_memory_space_atomic.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
namespace ck {
|
||||
// Define the common macro for MI300 models
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
|
||||
|
||||
@@ -8,7 +8,11 @@
|
||||
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
|
||||
#if __clang_major__ >= 20
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#else
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#endif
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
|
||||
2555
include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Normal file
2555
include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,11 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
|
||||
defined(__gfx942__) || defined(__gfx950__)
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
|
||||
#define __gfx9__
|
||||
#endif
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
|
||||
|
||||
@@ -5,7 +5,11 @@
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#if __clang_major__ >= 20
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp"
|
||||
#else
|
||||
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
|
||||
#endif
|
||||
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
@@ -207,6 +207,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
|
||||
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
|
||||
},
|
||||
number<row_ids_a.size()>{});
|
||||
|
||||
auto a_res =
|
||||
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
|
||||
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
|
||||
@@ -318,10 +319,10 @@ struct FusedMoeGemmPipeline_FlatmmUk
|
||||
{0, 0},
|
||||
dist_);
|
||||
}();
|
||||
|
||||
auto o_res =
|
||||
make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
|
||||
kargs.num_tokens * kargs.stride_token * sizeof(ODataType));
|
||||
|
||||
auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
|
||||
auto w_scale = GetWeightScale(
|
||||
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
|
||||
|
||||
@@ -68,15 +68,19 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN;
|
||||
static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
|
||||
|
||||
// TODO: Should we have two policies? Interwave & Intrawave ??
|
||||
// Controls how many MAC clusters (MFMA blocks) we have per wave
|
||||
// Ie if
|
||||
// InterWaveSchedulingMacClusters = 1;
|
||||
// KPerBlock == 32
|
||||
// WarpGemm::kK = 8
|
||||
// Then we would group all 4 WarpGemms into single MAC cluster.
|
||||
// But if we would set InterWaveSchedulingMacClusters = 2, then we would
|
||||
// split those 4 warp gemms into two groups.
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
|
||||
// should be at least equal to: WarpGemm::Impl::kABKPerLane
|
||||
// and the question is how to assess upper limit or exact value?
|
||||
// TODO: Should we introduce AK1/BK1 parameters ?
|
||||
static constexpr index_t KPack = 8;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * KPack;
|
||||
static constexpr index_t KRepeat = KPerThread / KPack;
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
|
||||
};
|
||||
|
||||
public:
|
||||
@@ -129,11 +133,12 @@ struct BlockUniversalGemmAsBsCr
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack);
|
||||
constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK;
|
||||
constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KIterInterWave>,
|
||||
sequence<KIterInterwave>,
|
||||
sequence<KIterPerWarp>>;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
@@ -153,11 +158,12 @@ struct BlockUniversalGemmAsBsCr
|
||||
{
|
||||
constexpr index_t KPerThread = Traits::KPerThread;
|
||||
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
|
||||
constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack);
|
||||
constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK;
|
||||
constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
|
||||
sequence<KIterInterWave>,
|
||||
sequence<KIterInterwave>,
|
||||
sequence<KIterPerWarp>>;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
@@ -371,11 +377,9 @@ struct BlockUniversalGemmAsBsCr
|
||||
static constexpr index_t KPerThread = GemmTraits::KPerThread;
|
||||
static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
|
||||
static constexpr index_t KPerInnerLoop =
|
||||
ck_tile::max(KPerThread / NumMacClusters, GemmTraits::KPack);
|
||||
// TODO: do we really need this?? Are there any cases when this would be >=1 ??
|
||||
// Would we need InterWaveSchedulingMacClusters > 1 ???
|
||||
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack;
|
||||
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
|
||||
|
||||
static constexpr auto ALdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};
|
||||
|
||||
@@ -136,7 +136,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
|
||||
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
@@ -196,7 +196,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
|
||||
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) /
|
||||
|
||||
@@ -252,7 +252,7 @@ struct UniversalGemmBasePolicy
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
|
||||
|
||||
@@ -109,7 +109,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
if(source MATCHES "_xdl")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
elseif(source MATCHES "_wmma")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
elseif(source MATCHES "mha")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
|
||||
@@ -10,7 +10,7 @@ if [ $# -ge 2 ] ; then
|
||||
shift 2
|
||||
REST_ARGS=$@
|
||||
else
|
||||
GPU_TARGETS="gfx908;gfx90a;gfx940"
|
||||
GPU_TARGETS="gfx908;gfx90a;gfx942"
|
||||
REST_ARGS=
|
||||
fi
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ if [ $# -ge 2 ] ; then
|
||||
shift 2
|
||||
REST_ARGS=$@
|
||||
else
|
||||
GPU_TARGETS="gfx908;gfx90a;gfx940"
|
||||
GPU_TARGETS="gfx908;gfx90a;gfx942"
|
||||
REST_ARGS=
|
||||
fi
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ function(add_test_executable TEST_NAME)
|
||||
if(ARGN MATCHES "_xdl")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
elseif(ARGN MATCHES "_wmma")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
elseif(ARGN MATCHES "_smfmac")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
@@ -196,11 +196,11 @@ function(add_gtest_executable TEST_NAME)
|
||||
if(ARGN MATCHES "_xdl")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
elseif(ARGN MATCHES "_wmma")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx950)
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
elseif(ARGN MATCHES "_smfmac")
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
elseif(ARGN MATCHES "_mx") #only build mx example for gfx950
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${TEST_NAME} ${ARGN})
|
||||
|
||||
Reference in New Issue
Block a user