mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
Navi21 gemm (#197)
* start adding navi21 GEMM * navi_gemm_km_kn_mn_fp32 compiles and passes one test. * rename variables and functions in gridwise_gemm_dlops_v1r3 * add other 3 layouts; format instance * adding more tuning parameters add tuning parameters for other 3 layouts * add gemm_dlops_f16 * tmp * add dependence of DeviceGemm::IsSupportedArg() on arch * minor changes * minor changes * minor changes * minor changes * minor changes * minor changes * minor changes * push gemm_dlops into profiler * minor changes * if using xdl or dlops is moved into profiler_gemm_impl * minor changes * minor changes * remove is_xdl from profile_gemm_impl * make IsSupportedArg dependent on arch for other device_gemm * minor changes * minor changes * fix a bug in f_generate_tensor_value * add 64x64x64 for gemm_dlops_int8 * add 64x64x64 for gemm_dlops_int8 * comment out 3 layouts in gemm_dlops_int8; add 32x32x32 for gemm_dlops_int8; init A values to 1 * fix * start fixing tuning parameters * monir * minor changes * minor changes * minor changes * fixing * adding example * adding example * adding example * add gemm fp32 example * clean up * use 128x128x16 as MNK tile in navi21 gemm example * bug fix * fix test * use new block c tile * clean * fix build Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: shaojiewang <wsjmessi@163.com>
This commit is contained in:
@@ -1,38 +1,38 @@
|
||||
#ifndef CK_GRIDWISE_GEMM_V1R3_HPP
|
||||
#define CK_GRIDWISE_GEMM_V1R3_HPP
|
||||
#pragma once
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_dlops_v2r3.hpp"
|
||||
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "blockwise_gemm_dl_v2r3.hpp"
|
||||
#include "blockwise_tensor_slice_transfer_v5r1.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v2.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_set.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AK0M0M1K1GridDesc,
|
||||
typename BK0N0N1K1GridDesc,
|
||||
typename CM0M10M11N0N10N11GridDesc,
|
||||
typename CBlockIdToM0N0BlockClusterAdaptor,
|
||||
typename AGridDesc_K0_M0_M1_K1,
|
||||
typename BGridDesc_K0_N0_N1_K1,
|
||||
typename CGridDesc_M0_M10_M11_N0_N10_N11,
|
||||
typename Block2CTileMap,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_dlops_v1r3(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc,
|
||||
const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc,
|
||||
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor)
|
||||
kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
|
||||
const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
|
||||
const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
@@ -43,10 +43,10 @@ __global__ void
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m0_m1_k1_grid_desc,
|
||||
b_k0_n0_n1_k1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
cblockid_to_m0_n0_block_cluster_adaptor,
|
||||
a_grid_desc_k0_m0_m1_k1,
|
||||
b_grid_desc_k0_n0_n1_k1,
|
||||
c_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
block_2_ctile_map,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
@@ -56,12 +56,12 @@ template <index_t BlockSize,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CMNGridDesc,
|
||||
index_t MPerBlockM1,
|
||||
index_t NPerBlockN1,
|
||||
index_t KPerBlock,
|
||||
typename AGridDesc_K0_M_K1,
|
||||
typename BGridDesc_K0_N_K1,
|
||||
typename CGridDesc_M_N,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t K0PerBlock,
|
||||
index_t M1PerThreadM111,
|
||||
index_t N1PerThreadN111,
|
||||
index_t KPerThread,
|
||||
@@ -83,13 +83,8 @@ template <index_t BlockSize,
|
||||
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
index_t CThreadTransferDstScalarPerVector>
|
||||
struct GridwiseGemmDl_km_kn_mn_v1r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -97,7 +92,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
|
||||
static constexpr auto K1 = AGridDesc_K0_M_K1{}.GetLength(I2);
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
@@ -106,112 +101,112 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
|
||||
// TODO: check alignment
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
|
||||
constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
|
||||
constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size =
|
||||
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size =
|
||||
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc)
|
||||
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
|
||||
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
|
||||
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
|
||||
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
|
||||
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
|
||||
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
|
||||
return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
|
||||
K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
|
||||
K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
|
||||
K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
|
||||
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
|
||||
{
|
||||
const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);
|
||||
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
|
||||
{
|
||||
const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1;
|
||||
const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1;
|
||||
|
||||
return has_main_k_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
|
||||
{
|
||||
const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0;
|
||||
const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0;
|
||||
|
||||
return has_double_tail_k_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc)
|
||||
MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
|
||||
{
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
|
||||
|
||||
const auto M1 = Number<MPerBlockM1>{};
|
||||
const auto M1 = Number<MPerBlock>{};
|
||||
const auto M0 = M / M1;
|
||||
|
||||
const auto a_k0_m0_m1_k1_grid_desc =
|
||||
transform_tensor_descriptor(a_k0_m_k1_grid_desc,
|
||||
const auto a_grid_desc_k0_m0_m1_k1 =
|
||||
transform_tensor_descriptor(a_grid_desc_k0_m_k1,
|
||||
make_tuple(make_pass_through_transform(K0),
|
||||
make_unmerge_transform(make_tuple(M0, M1)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
return a_k0_m0_m1_k1_grid_desc;
|
||||
return a_grid_desc_k0_m0_m1_k1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
|
||||
MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
|
||||
{
|
||||
const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
|
||||
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
|
||||
|
||||
const auto N1 = Number<NPerBlockN1>{};
|
||||
const auto N1 = Number<NPerBlock>{};
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto b_k0_n0_n1_k1_grid_desc =
|
||||
transform_tensor_descriptor(b_k0_n_k1_grid_desc,
|
||||
const auto b_grid_desc_k0_n0_n1_k1 =
|
||||
transform_tensor_descriptor(b_grid_desc_k0_n_k1,
|
||||
make_tuple(make_pass_through_transform(K0),
|
||||
make_unmerge_transform(make_tuple(N0, N1)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
return b_k0_n0_n1_k1_grid_desc;
|
||||
return b_grid_desc_k0_n0_n1_k1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlockM1>{};
|
||||
constexpr auto N1 = Number<NPerBlockN1>{};
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
@@ -226,41 +221,29 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
constexpr auto M10 = M1 / M11;
|
||||
constexpr auto N10 = N1 / N11;
|
||||
|
||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
|
||||
c_m_n_grid_desc,
|
||||
const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
|
||||
make_unmerge_transform(make_tuple(N0, N10, N11))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
|
||||
|
||||
return c_m0_m10_m11_n0_n10_n11_grid_desc;
|
||||
return c_grid_desc_m0_m10_m11_n0_n10_n11;
|
||||
}
|
||||
|
||||
// return block_id to C matrix tile idx (m0, n0) mapping
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlockM1>{};
|
||||
constexpr auto N1 = Number<NPerBlockN1>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto cblockid_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return cblockid_to_m0_n0_block_cluster_adaptor;
|
||||
return BlockToCTileMap_M00_N00_M01_N01<MPerBlock, NPerBlock, CGridDesc_M_N>(
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{}));
|
||||
using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{}));
|
||||
using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
|
||||
using CBlockIdToM0N0BlockClusterAdaptor =
|
||||
decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));
|
||||
using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
|
||||
using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));
|
||||
using CGridDesc_M0_M10_M11_N0_N10_N11 =
|
||||
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
|
||||
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
|
||||
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ static void
|
||||
@@ -268,57 +251,64 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc,
|
||||
const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc,
|
||||
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor,
|
||||
const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
|
||||
const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
|
||||
const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>)
|
||||
{
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
|
||||
p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
|
||||
p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
|
||||
p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto c_m0_n0_block_cluster_idx =
|
||||
cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(get_block_1d_id()));
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this force index data into SGPR
|
||||
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
|
||||
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
make_tuple(im0, in0),
|
||||
make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
|
||||
c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO: change this. I think it needs multi-dimensional alignment
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// TODO: check alignment
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);
|
||||
constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);
|
||||
constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
// A matrix in LDS memory, for blockwise GEMM
|
||||
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
// B matrix in LDS memory, for blockwise GEMM
|
||||
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
|
||||
static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
|
||||
static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
|
||||
a_k0_m_k1_block_desc.GetElementSpaceSize() &&
|
||||
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() ==
|
||||
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
|
||||
b_k0_n_k1_block_desc.GetElementSpaceSize() &&
|
||||
"wrong!");
|
||||
|
||||
@@ -326,14 +316,14 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
|
||||
BlockSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
|
||||
Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_k0_m0_m1_k1_grid_desc),
|
||||
decltype(a_k0_m0_m1_k1_block_desc),
|
||||
remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
|
||||
decltype(a_block_desc_k0_m0_m1_k1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
|
||||
@@ -341,23 +331,23 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
|
||||
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
|
||||
false,
|
||||
true>(a_k0_m0_m1_k1_grid_desc,
|
||||
true>(a_grid_desc_k0_m0_m1_k1,
|
||||
make_multi_index(0, im0, 0, 0),
|
||||
a_k0_m0_m1_k1_block_desc,
|
||||
a_block_desc_k0_m0_m1_k1,
|
||||
make_multi_index(0, 0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
|
||||
BlockSize,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
|
||||
Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_k0_n0_n1_k1_grid_desc),
|
||||
decltype(b_k0_n0_n1_k1_block_desc),
|
||||
remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
|
||||
decltype(b_block_desc_k0_n0_n1_k1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
|
||||
@@ -365,19 +355,19 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
|
||||
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
|
||||
false,
|
||||
true>(b_k0_n0_n1_k1_grid_desc,
|
||||
true>(b_grid_desc_k0_n0_n1_k1,
|
||||
make_multi_index(0, in0, 0, 0),
|
||||
b_k0_n0_n1_k1_block_desc,
|
||||
b_block_desc_k0_n0_n1_k1,
|
||||
make_multi_index(0, 0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlockM1] is in LDS
|
||||
// b_mtx[KPerBlocl, NPerBlockN1] is in LDS
|
||||
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
|
||||
// a_mtx[K0PerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[KPerBlocl, NPerBlock] is in LDS
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
|
||||
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
@@ -395,58 +385,53 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
|
||||
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
|
||||
|
||||
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
|
||||
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
|
||||
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block_double = p_shared_block;
|
||||
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
|
||||
|
||||
// register allocation for output
|
||||
auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
|
||||
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
|
||||
c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());
|
||||
|
||||
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_m10_m11_n10_n11_thread_desc),
|
||||
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
|
||||
.Run(c_m10_m11_n10_n11_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
FloatAcc{0});
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
|
||||
|
||||
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
|
||||
p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
|
||||
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
|
||||
p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
|
||||
|
||||
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_a_block_double + a_block_aligned_space_size,
|
||||
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
|
||||
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
|
||||
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
p_b_block_double + b_block_aligned_space_size,
|
||||
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
|
||||
b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
|
||||
}
|
||||
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0);
|
||||
const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
|
||||
|
||||
index_t k_block_data_begin = 0;
|
||||
|
||||
@@ -455,82 +440,76 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
do
|
||||
{
|
||||
// even iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
|
||||
a_block_slice_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
|
||||
b_block_slice_copy_step);
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
||||
blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11,
|
||||
a_block_even_buf,
|
||||
b_block_even_buf,
|
||||
c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
|
||||
|
||||
// odd iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
|
||||
a_block_slice_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
|
||||
b_block_slice_copy_step);
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
|
||||
|
||||
k_block_data_begin += 2 * KPerBlock;
|
||||
} while(k_block_data_begin < K0 - 2 * KPerBlock);
|
||||
k_block_data_begin += 2 * K0PerBlock;
|
||||
} while(k_block_data_begin < K0 - 2 * K0PerBlock);
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(
|
||||
a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(
|
||||
b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);
|
||||
|
||||
__syncthreads();
|
||||
block_sync_lds();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
|
||||
|
||||
__syncthreads();
|
||||
block_sync_lds();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
@@ -538,12 +517,12 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
|
||||
constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
|
||||
@@ -559,8 +538,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
|
||||
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
|
||||
decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
|
||||
decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<1,
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
|
||||
@@ -572,22 +552,21 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
make_multi_index(im0,
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
|
||||
in0,
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
|
||||
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
|
||||
ck::tensor_operation::element_wise::PassThrough{}}
|
||||
.Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_grid_buf,
|
||||
CGridStepHacks{});
|
||||
c_grid_desc_m0_m10_m11_n0_n10_n11,
|
||||
c_grid_buf);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,4 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
|
||||
Reference in New Issue
Block a user