mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
adding thread group
This commit is contained in:
@@ -26,17 +26,14 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// buffer resourse, wave size
|
||||
// buffer resourse
|
||||
#ifndef __HIP_DEVICE_COMPILE__ // for host code
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
|
||||
#define CK_GPU_WAVE_SIZE -1
|
||||
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) // for GPU code
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
|
||||
#define CK_GPU_WAVE_SIZE 64
|
||||
#elif defined(__gfx1030__) // for GPU code
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
||||
#define CK_GPU_WAVE_SIZE 32
|
||||
#endif
|
||||
|
||||
// FMA instruction
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
|
||||
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
|
||||
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "xdlops_gemm.hpp"
|
||||
@@ -8,7 +6,7 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
template <typename ThreadGroup,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename AK0MK1BlockDesc,
|
||||
@@ -25,7 +23,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr index_t WaveSize = 64;
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
|
||||
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
|
||||
@@ -53,7 +51,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
|
||||
__device__ static auto GetWaveIdx()
|
||||
{
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
const index_t thread_id = ThreadGroup::GetThreadId();
|
||||
|
||||
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
|
||||
@@ -120,8 +118,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
BK0NK1BlockDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(BlockSize == MWaves * NWaves * WaveSize,
|
||||
"BlockSize != MWaves * NWaves * WaveSize\n");
|
||||
static_assert(ThreadGroup::GetNumOfThread() == MWaves * NWaves * WaveSize,
|
||||
"ThreadGroup::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
|
||||
|
||||
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
|
||||
"wrong!");
|
||||
@@ -337,4 +335,3 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
|
||||
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
@@ -169,4 +167,3 @@ struct BlockwiseTensorSliceTransfer_v4r1
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include "transpose_vectors.hpp"
|
||||
#include "inner_product.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "thread_group.hpp"
|
||||
#include "debug.hpp"
|
||||
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
|
||||
@@ -3,11 +3,14 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
__device__ constexpr index_t get_wave_size() { return CK_GPU_WAVE_SIZE; }
|
||||
__host__ __device__ constexpr index_t get_warp_size()
|
||||
{ // warpSize is defined by HIP
|
||||
return warpSize;
|
||||
}
|
||||
|
||||
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
|
||||
|
||||
__device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size(); }
|
||||
__device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); }
|
||||
|
||||
__device__ index_t get_block_1d_id() { return blockIdx.x; }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user