mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
Add xdlops v4r4r4 into online compilation (#48)
* init for v4r4 xdlops olc * refactor wrap * init impl of v4r4 nchw xdlops olc * tuning * test perf * fixed v4r4 nhwc * tuned v4r4 nhwc * use gridwise_gemm_xdlops_v2r3 * swap a/b * add pointer support into offline v2r3 * debugging v4r4r4 transform for olc * change timer of olc * refactor v4r4 xdlops nchw olc * remove transform fun in v4r4 xdlops nhwc olc Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -1,365 +0,0 @@
|
||||
#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "driver_dynamic_gemm_xdlops_v1.hpp"
|
||||
#include "driver_dynamic_gemm_xdlops_v2.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <typename FloatAB,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmMPerWave,
|
||||
index_t GemmNPerWave,
|
||||
index_t GemmKPack,
|
||||
typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
|
||||
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = N * Ho * Wo;
|
||||
const auto GemmK = C * Y * X;
|
||||
const auto GemmK0 = GemmK / GemmKPack;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
|
||||
wei_gemmk_gemmm_global_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_global_desc =
|
||||
transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
|
||||
in_gemmk_gemmn_global_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0));
|
||||
assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1));
|
||||
assert(GemmK0 == in_gemmk0_gemmn_gemmk1_global_desc.GetLength(I0));
|
||||
assert(GemmK0 == wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0));
|
||||
|
||||
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0);
|
||||
|
||||
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPack>{};
|
||||
|
||||
constexpr auto CLayout = xdlops_gemm.GetCLayout();
|
||||
|
||||
constexpr index_t M0 = CLayout.M1();
|
||||
constexpr index_t M1 = CLayout.N1();
|
||||
constexpr index_t M2 = CLayout.M0();
|
||||
|
||||
const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
|
||||
out_gemmm_gemmn_global_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmM / (M1 * M2), M1, M2)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
|
||||
|
||||
// out_gemm_block_cluster_desc
|
||||
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
|
||||
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
|
||||
|
||||
// hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
|
||||
// tensor hack for NKHW format
|
||||
constexpr auto out_m0_m1_m2_n_global_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc,
|
||||
in_gemmk0_gemmn_gemmk1_global_desc,
|
||||
out_m0_m1_m2_n_global_desc,
|
||||
out_gemm_block_cluster_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_global_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_global_iterator_hacks,
|
||||
out_m0_m1_m2_n_global_iterator_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks);
|
||||
}
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <typename FloatAB,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmMPerWave,
|
||||
index_t GemmNPerWave,
|
||||
index_t GemmKPack,
|
||||
typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1(
|
||||
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = N * Ho * Wo;
|
||||
const auto GemmK = C * Y * X;
|
||||
const auto GemmK0 = GemmK / GemmKPack;
|
||||
|
||||
assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
|
||||
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
|
||||
InRightPadW == 0);
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
|
||||
wei_gemmk_gemmm_global_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor(
|
||||
in_gemmk_gemmn_global_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0));
|
||||
assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1));
|
||||
assert(GemmK0 == in_gemmk0_gemmn_gemmk1_global_desc.GetLength(I0));
|
||||
assert(GemmK0 == wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0));
|
||||
|
||||
assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0);
|
||||
|
||||
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, GemmMPerWave, GemmNPerWave, GemmKPack>{};
|
||||
|
||||
constexpr auto CLayout = xdlops_gemm.GetCLayout();
|
||||
|
||||
constexpr index_t M0 = CLayout.M1();
|
||||
constexpr index_t M1 = CLayout.N1();
|
||||
constexpr index_t M2 = CLayout.M0();
|
||||
|
||||
const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
|
||||
out_gemmm_gemmn_global_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmM / (M1 * M2), M1, M2)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
|
||||
|
||||
// out_gemm_block_cluster_desc
|
||||
const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(
|
||||
make_tuple(GemmM / Number<GemmMPerBlock>{}, GemmN / Number<GemmNPerBlock>{}));
|
||||
|
||||
// hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks =
|
||||
Sequence<0, 1, 2, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
|
||||
// tensor hack for NKHW format
|
||||
constexpr auto out_m0_m1_m2_n_global_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc,
|
||||
in_gemmk0_gemmn_gemmk1_global_desc,
|
||||
out_m0_m1_m2_n_global_desc,
|
||||
out_gemm_block_cluster_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_global_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_global_iterator_hacks,
|
||||
out_m0_m1_m2_n_global_iterator_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,384 +0,0 @@
|
||||
#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V1
|
||||
#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V1
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_dynamic_gemm_xdlops.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
typename CBlockClusterDesc,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPack,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadSliceLengths_K_M,
|
||||
typename ABlockTransferThreadClusterLengths_K_M,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_M,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K_N,
|
||||
typename BBlockTransferThreadClusterLengths_K_N,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_N,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGlobalIteratorHacks,
|
||||
typename BGlobalIteratorHacks,
|
||||
typename CGlobalIteratorHacks,
|
||||
typename AGlobalMoveSliceWindowIteratorHacks,
|
||||
typename BGlobalMoveSliceWindowIteratorHacks>
|
||||
__host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
|
||||
const FloatAB* p_b_global,
|
||||
FloatC* p_c_global,
|
||||
const AGlobalDesc& a_k_m_global_desc,
|
||||
const BGlobalDesc& b_k_n_global_desc,
|
||||
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
|
||||
const CBlockClusterDesc& c_block_cluster_desc,
|
||||
AGlobalIteratorHacks,
|
||||
BGlobalIteratorHacks,
|
||||
CGlobalIteratorHacks,
|
||||
AGlobalMoveSliceWindowIteratorHacks,
|
||||
BGlobalMoveSliceWindowIteratorHacks,
|
||||
index_t nrepeat)
|
||||
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto M = a_k_m_global_desc.GetLength(I1);
|
||||
const auto N = b_k_n_global_desc.GetLength(I1);
|
||||
const auto K = a_k_m_global_desc.GetLength(I0);
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
if(!(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// GEMM
|
||||
using gridwise_gemm =
|
||||
GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
AGlobalDesc,
|
||||
BGlobalDesc,
|
||||
CGlobalDesc,
|
||||
CBlockClusterDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPack,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K_M,
|
||||
ABlockTransferThreadClusterLengths_K_M,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_M,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K_N,
|
||||
BBlockTransferThreadClusterLengths_K_N,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_N,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGlobalIteratorHacks,
|
||||
BGlobalIteratorHacks,
|
||||
CGlobalIteratorHacks,
|
||||
AGlobalMoveSliceWindowIteratorHacks,
|
||||
BGlobalMoveSliceWindowIteratorHacks>;
|
||||
|
||||
const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
|
||||
|
||||
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
|
||||
|
||||
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
|
||||
|
||||
std::cerr << "has_main_k_block_loop = " << has_main_k_block_loop
|
||||
<< " has_double_tail_k_block_loop = " << has_double_tail_k_block_loop << std::endl;
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
true,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k_m_global_desc,
|
||||
b_k_n_global_desc,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
c_block_cluster_desc);
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
true,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k_m_global_desc,
|
||||
b_k_n_global_desc,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
c_block_cluster_desc);
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
false,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k_m_global_desc,
|
||||
b_k_n_global_desc,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
c_block_cluster_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
false,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k_m_global_desc,
|
||||
b_k_n_global_desc,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
c_block_cluster_desc);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
|
||||
DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
|
||||
DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
|
||||
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
|
||||
|
||||
a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
|
||||
b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
|
||||
c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
|
||||
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
true,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
true,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
false,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>,
|
||||
false,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,202 +0,0 @@
|
||||
#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2
|
||||
#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_dynamic_gemm_xdlops_v2.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
typename CBlockClusterDesc,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPack,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadSliceLengths_K_M,
|
||||
typename ABlockTransferThreadClusterLengths_K_M,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_M,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K_N,
|
||||
typename BBlockTransferThreadClusterLengths_K_N,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_N,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGlobalIteratorHacks,
|
||||
typename BGlobalIteratorHacks,
|
||||
typename CGlobalIteratorHacks,
|
||||
typename AGlobalMoveSliceWindowIteratorHacks,
|
||||
typename BGlobalMoveSliceWindowIteratorHacks>
|
||||
__host__ float launch_kernel_dynamic_gemm_xdlops_v2(const FloatAB* p_a_global,
|
||||
const FloatAB* p_b_global,
|
||||
FloatC* p_c_global,
|
||||
const AGlobalDesc& a_k_m_global_desc,
|
||||
const BGlobalDesc& b_k_n_global_desc,
|
||||
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
|
||||
const CBlockClusterDesc& c_block_cluster_desc,
|
||||
AGlobalIteratorHacks,
|
||||
BGlobalIteratorHacks,
|
||||
CGlobalIteratorHacks,
|
||||
AGlobalMoveSliceWindowIteratorHacks,
|
||||
BGlobalMoveSliceWindowIteratorHacks,
|
||||
index_t nrepeat)
|
||||
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto M = a_k_m_global_desc.GetLength(I1);
|
||||
const auto N = b_k_n_global_desc.GetLength(I1);
|
||||
const auto K = a_k_m_global_desc.GetLength(I0);
|
||||
|
||||
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
if(!(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// GEMM
|
||||
using gridwise_gemm =
|
||||
GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
AGlobalDesc,
|
||||
BGlobalDesc,
|
||||
CGlobalDesc,
|
||||
CBlockClusterDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPack,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K_M,
|
||||
ABlockTransferThreadClusterLengths_K_M,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_M,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K_N,
|
||||
BBlockTransferThreadClusterLengths_K_N,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_N,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGlobalIteratorHacks,
|
||||
BGlobalIteratorHacks,
|
||||
CGlobalIteratorHacks,
|
||||
AGlobalMoveSliceWindowIteratorHacks,
|
||||
BGlobalMoveSliceWindowIteratorHacks>;
|
||||
|
||||
const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
float ave_time = 0;
|
||||
|
||||
const auto kernel = kernel_dynamic_gemm_xdlops_v2<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k_m_global_desc,
|
||||
b_k_n_global_desc,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
c_block_cluster_desc);
|
||||
|
||||
return ave_time;
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
|
||||
DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
|
||||
DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
|
||||
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
|
||||
|
||||
a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
|
||||
b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
|
||||
c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
|
||||
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
const auto kernel = kernel_dynamic_gemm_xdlops_v1<gridwise_gemm,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGlobalDesc>,
|
||||
remove_reference_t<BGlobalDesc>,
|
||||
remove_reference_t<CGlobalDesc>,
|
||||
remove_reference_t<CBlockClusterDesc>>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
|
||||
|
||||
return ave_time;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,167 +0,0 @@
|
||||
#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R2
|
||||
#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R2
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_dynamic_gemm_xdlops_v2r2.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CMNGridDesc,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridIteratorHacks,
|
||||
typename BGridIteratorHacks,
|
||||
typename CGridIteratorHacks,
|
||||
typename AGridMoveSliceWindowIteratorHacks,
|
||||
typename BGridMoveSliceWindowIteratorHacks>
|
||||
__host__ float driver_dynamic_gemm_xdlops_v2r2(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks,
|
||||
index_t nrepeat)
|
||||
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r2<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks>;
|
||||
|
||||
{
|
||||
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
|
||||
<< a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
|
||||
<< "}" << std::endl;
|
||||
|
||||
std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
|
||||
<< b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
|
||||
<< "}" << std::endl;
|
||||
|
||||
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
|
||||
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2r2 has invalid setting");
|
||||
}
|
||||
|
||||
const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
|
||||
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc);
|
||||
|
||||
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
|
||||
|
||||
const auto kernel = kernel_dynamic_gemm_xdlops_v2r2<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AK0MK1GridDesc>,
|
||||
remove_reference_t<BK0NK1GridDesc>,
|
||||
remove_reference_t<CM0M1M2NGridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>>;
|
||||
|
||||
float ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -21,6 +21,7 @@ template <index_t BlockSize,
|
||||
index_t KPerBlock,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t K1,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
@@ -83,6 +84,7 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
@@ -148,6 +150,7 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
remove_reference_t<CM0M1M2NGridDesc>,
|
||||
remove_reference_t<CBlockClusterAdaptor>>;
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
float ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
@@ -162,6 +165,32 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
|
||||
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
|
||||
DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc));
|
||||
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
|
||||
|
||||
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
|
||||
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
|
||||
c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
|
||||
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
|
||||
|
||||
float ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
(void __CONSTANT__*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void __CONSTANT__*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
||||
#endif
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,129 +0,0 @@
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
|
||||
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = N * Ho * Wo;
|
||||
const auto GemmK = C * Y * X;
|
||||
const auto GemmK0 = GemmK / GemmK1;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
wei_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_grid_desc =
|
||||
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
in_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -13,7 +13,7 @@ template <index_t BlockSize,
|
||||
class BBlockDesc,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPack>
|
||||
index_t K1>
|
||||
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
|
||||
{
|
||||
|
||||
@@ -32,7 +32,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
|
||||
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, KPack>{};
|
||||
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
|
||||
|
||||
static constexpr index_t MWaves = M1 / MPerWave;
|
||||
static constexpr index_t NWaves = N1 / NPerWave;
|
||||
@@ -119,18 +119,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
|
||||
"wrong! K dimension not consistent");
|
||||
|
||||
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
|
||||
"wrong! KPack dimension not consistent");
|
||||
"wrong! K1 dimension not consistent");
|
||||
|
||||
static_assert(BlockSize == MWaves * NWaves * WaveSize,
|
||||
"BlockSize != MWaves * NWaves * WaveSize\n");
|
||||
|
||||
static_assert(KPack == BBlockDesc{}.GetLength(I3), "KPack is wrong!");
|
||||
static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
|
||||
|
||||
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
|
||||
|
||||
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
|
||||
|
||||
static_assert(KPack % xdlops_gemm.mfma_type.k_base == 0, "KPack is wrong!");
|
||||
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
|
||||
}
|
||||
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
@@ -194,11 +194,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
|
||||
private:
|
||||
// A[K, M]
|
||||
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(I1, Number<MRepeat>{}, I1, Number<KPack>{}));
|
||||
make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
|
||||
|
||||
// B[K, N]
|
||||
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(I1, Number<NRepeat>{}, I1, Number<KPack>{}));
|
||||
make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
|
||||
|
||||
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
||||
@@ -207,20 +207,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
|
||||
FloatAB,
|
||||
ABlockDesc,
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, MRepeat, 1, KPack>,
|
||||
Sequence<1, MRepeat, 1, K1>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
KPack,
|
||||
K1,
|
||||
1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
BBlockDesc,
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, NRepeat, 1, KPack>,
|
||||
Sequence<1, NRepeat, 1, K1>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
KPack,
|
||||
K1,
|
||||
1>;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
@@ -233,7 +233,7 @@ template <index_t BlockSize,
|
||||
class BBlockDesc,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPack>
|
||||
index_t K1>
|
||||
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
|
||||
{
|
||||
|
||||
@@ -244,7 +244,7 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, KPack>{};
|
||||
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, K1>{};
|
||||
|
||||
static constexpr index_t WaveSize = 64;
|
||||
|
||||
@@ -339,18 +339,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
|
||||
"wrong! K dimension not consistent");
|
||||
|
||||
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
|
||||
"wrong! KPack dimension not consistent");
|
||||
"wrong! K1 dimension not consistent");
|
||||
|
||||
static_assert(BlockSize == MWaves * NWaves * WaveSize,
|
||||
"BlockSize != MWaves * NWaves * WaveSize\n");
|
||||
|
||||
static_assert(KPack == BBlockDesc{}.GetLength(I3), "KPack is wrong!");
|
||||
static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
|
||||
|
||||
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
|
||||
|
||||
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
|
||||
|
||||
static_assert(KPack % xdlops_gemm.mfma_type.k_base == 0, "KPack is wrong!");
|
||||
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
|
||||
}
|
||||
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
@@ -491,11 +491,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
|
||||
private:
|
||||
// A[K, M]
|
||||
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(I1, Number<MRepeat>{}, I1, Number<KPack>{}));
|
||||
make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
|
||||
|
||||
// B[K, N]
|
||||
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(I1, Number<NRepeat>{}, I1, Number<KPack>{}));
|
||||
make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
|
||||
|
||||
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
||||
@@ -504,20 +504,20 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
|
||||
FloatAB,
|
||||
ABlockDesc,
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<1, 1, 1, K1>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
1, // KPack,
|
||||
1, // K1,
|
||||
1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
BBlockDesc,
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<1, 1, 1, K1>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
1, // KPack,
|
||||
1, // K1,
|
||||
1>;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
|
||||
@@ -1,585 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_HPP
|
||||
#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_multi_index_transform_helper.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_xdlops.hpp"
|
||||
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
typename CBlockClusterDesc,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
|
||||
const FloatB* __restrict__ p_b_global,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
const AGlobalDesc a_k0_m_k1_global_desc,
|
||||
const BGlobalDesc b_k0_n_k1_global_desc,
|
||||
const CGlobalDesc c_m0_m1_m2_n_global_desc,
|
||||
const CBlockClusterDesc c_block_cluster_desc)
|
||||
{
|
||||
GridwiseGemm::Run(p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k0_m_k1_global_desc,
|
||||
b_k0_n_k1_global_desc,
|
||||
c_m0_m1_m2_n_global_desc,
|
||||
c_block_cluster_desc,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
// pass tensor descriptor by __CONSTANT__ void pointer
|
||||
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
|
||||
// non-modifiable parameter address space, so compiler can enable corresponding optimization
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
typename CBlockClusterDesc,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
|
||||
const FloatB* __restrict__ p_b_global,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
const void __CONSTANT__* p_a_k0_m_k1_global_desc,
|
||||
const void __CONSTANT__* p_b_k0_n_k1_global_desc,
|
||||
const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc,
|
||||
const void __CONSTANT__* p_c_block_cluster_desc)
|
||||
{
|
||||
// first cast void __CONSTANT__ void* to void*
|
||||
// second cast void* to Desc*
|
||||
// the copy constructor of tensor descriptor doesn't take address_space(4)
|
||||
const auto a_k0_m_k1_global_desc =
|
||||
*reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k0_m_k1_global_desc);
|
||||
const auto b_k0_n_k1_global_desc =
|
||||
*reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k0_n_k1_global_desc);
|
||||
const auto c_m0_m1_m2_n_global_desc =
|
||||
*reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_m2_n_global_desc);
|
||||
|
||||
const auto c_block_cluster_desc =
|
||||
*reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
|
||||
|
||||
GridwiseGemm::Run(p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k0_m_k1_global_desc,
|
||||
b_k0_n_k1_global_desc,
|
||||
c_m0_m1_m2_n_global_desc,
|
||||
c_block_cluster_desc,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
#endif
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
typename CBlockClusterDesc,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPack,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadSliceLengths_K_M_KPack,
|
||||
typename ABlockTransferThreadClusterLengths_K_M_KPack,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_KPack,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K_N_KPack,
|
||||
typename BBlockTransferThreadClusterLengths_K_N_KPack,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_KPack,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGlobalIteratorHacks,
|
||||
typename BGlobalIteratorHacks,
|
||||
typename CGlobalIteratorHacks,
|
||||
typename AGlobalMoveSliceWindowIteratorHacks,
|
||||
typename BGlobalMoveSliceWindowIteratorHacks>
|
||||
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
|
||||
{
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto max_lds_align = Number<KPack>{};
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size =
|
||||
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
const AGlobalDesc& a_k0_m_k1_global_desc,
|
||||
const BGlobalDesc& b_k0_n_k1_global_desc,
|
||||
const CGlobalDesc& c_m0_m1_m2_n_global_desc,
|
||||
const CBlockClusterDesc& c_block_cluster_desc,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_a_global, a_k0_m_k1_global_desc.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_b_global, b_k0_n_k1_global_desc.GetElementSpaceSize());
|
||||
auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize());
|
||||
|
||||
const auto K0 = a_k0_m_k1_global_desc.GetLength(I0);
|
||||
const auto M = a_k0_m_k1_global_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_global_desc.GetLength(I1);
|
||||
const auto K1 = b_k0_n_k1_global_desc.GetLength(I2);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_work_idx =
|
||||
c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_global into SGPR
|
||||
const index_t m_block_data_idx_on_global =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_global =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = Number<KPack>{};
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperation::Set,
|
||||
Sequence<KPerBlock, MPerBlock, KPack>,
|
||||
ABlockTransferThreadSliceLengths_K_M_KPack,
|
||||
ABlockTransferThreadClusterLengths_K_M_KPack,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_k0_m_k1_global_desc),
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_KPack,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
a_k0_m_k1_global_desc,
|
||||
make_multi_index(0, m_block_data_idx_on_global, 0),
|
||||
a_k0_m_k1_block_desc,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperation::Set,
|
||||
Sequence<KPerBlock, NPerBlock, KPack>,
|
||||
BBlockTransferThreadSliceLengths_K_N_KPack,
|
||||
BBlockTransferThreadClusterLengths_K_N_KPack,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_k0_n_k1_global_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_KPack,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
b_k0_n_k1_global_desc,
|
||||
make_multi_index(0, n_block_data_idx_on_global, 0),
|
||||
b_k0_n_k1_block_desc,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[KPerBlock, NPerBlock] is in LDS
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
|
||||
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
|
||||
NPerBlock % (NPerWave * NRepeat) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
|
||||
a_k0_m_k1_block_desc,
|
||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
|
||||
make_pass_through_transform(Number<KPack>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
|
||||
b_k0_n_k1_block_desc,
|
||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
|
||||
make_pass_through_transform(Number<KPack>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
|
||||
FloatAB,
|
||||
decltype(a_k0_m0_m1_k1_block_desc),
|
||||
decltype(b_k0_n0_n1_k1_block_desc),
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPack>{};
|
||||
|
||||
constexpr auto CLayout = blockwise_gemm.GetCLayout();
|
||||
|
||||
constexpr index_t BlkSize = CLayout.GetBlkSize();
|
||||
constexpr index_t NumBlks = CLayout.GetNumBlks();
|
||||
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
|
||||
|
||||
constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<NumXdlops>{}));
|
||||
|
||||
constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<NumBlks>{}, Number<BlkSize>{}));
|
||||
|
||||
StaticBuffer<AddressSpace::Vgpr,
|
||||
vector_type<FloatAcc, c_blk_nb_bs_desc.GetElementSpaceSize()>,
|
||||
c_mr_nr_nx_desc.GetElementSpaceSize()>
|
||||
c_thread_buf;
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size =
|
||||
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block_double = p_shared_block;
|
||||
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
|
||||
|
||||
// register allocation for output
|
||||
// auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
|
||||
// c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
|
||||
|
||||
// ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
|
||||
// decltype(c_m0_m1_n0_n1_thread_desc),
|
||||
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
|
||||
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||
constexpr auto a_k0_m_k1_global_iterator_hacks = AGlobalIteratorHacks{};
|
||||
constexpr auto b_k0_n_k1_global_iterator_hacks = BGlobalIteratorHacks{};
|
||||
|
||||
// hack to control index calculation when move slice window for A and B matrix for
|
||||
// threadwise copy
|
||||
constexpr auto a_k0_m_k1_global_move_slice_window_iterator_hack =
|
||||
AGlobalMoveSliceWindowIteratorHacks{};
|
||||
constexpr auto b_k0_n_k1_global_move_slice_window_iterator_hack =
|
||||
BGlobalMoveSliceWindowIteratorHacks{};
|
||||
|
||||
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_a_block_double, a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_b_block_double, b_k0_n_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_a_block_double + a_block_space_size, a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_b_block_double + b_block_space_size, b_k0_n_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf);
|
||||
}
|
||||
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
index_t k_block_data_begin = 0;
|
||||
|
||||
// LDS double buffer: main body
|
||||
// use Do-While loop instead of For loop to simplify control flow
|
||||
do
|
||||
{
|
||||
// even iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(
|
||||
a_k0_m_k1_global_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k0_m_k1_global_move_slice_window_iterator_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(
|
||||
b_k0_n_k1_global_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k0_n_k1_global_move_slice_window_iterator_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
|
||||
|
||||
asm volatile("s_nop 0");
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf);
|
||||
|
||||
// odd iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(
|
||||
a_k0_m_k1_global_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k0_m_k1_global_move_slice_window_iterator_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(
|
||||
b_k0_n_k1_global_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k0_n_k1_global_move_slice_window_iterator_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
|
||||
|
||||
asm volatile("s_nop 0");
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf);
|
||||
|
||||
k_block_data_begin += 2 * KPerBlock;
|
||||
} while(k_block_data_begin < K0 - 2 * KPerBlock);
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k0_m_k1_global_move_slice_window_iterator_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k0_n_k1_global_move_slice_window_iterator_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
|
||||
constexpr index_t M0 = CLayout.M1();
|
||||
constexpr index_t M1 = CLayout.N1();
|
||||
constexpr index_t M2 = CLayout.M0();
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_thread_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
|
||||
|
||||
StaticBuffer<AddressSpace::Vgpr, FloatC, BlkSize> c_blk_buf_;
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
|
||||
static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) {
|
||||
static_for<0, NumBlks, 1>{}([&](auto blk_i) {
|
||||
auto c_blk = c_thread_buf[Number<c_mr_nr_nx_desc.CalculateOffset(
|
||||
make_tuple(mr_i, nr_i, xdlops_i))>{}];
|
||||
|
||||
static_for<0, BlkSize, 1>{}([&](auto j) {
|
||||
c_blk_buf_(j) = c_blk.template AsType<FloatAcc>()[Number<
|
||||
c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}];
|
||||
});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.CalculateCThreadOriginDataIndex(
|
||||
mr_i, nr_i, xdlops_i, blk_i);
|
||||
|
||||
const index_t m_thread_data_on_global =
|
||||
m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
|
||||
|
||||
const index_t n_thread_data_on_global =
|
||||
n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks =
|
||||
CGlobalIteratorHacks{};
|
||||
|
||||
ThreadwiseDynamicTensorSliceTransfer_v1r3<
|
||||
FloatC,
|
||||
FloatC,
|
||||
decltype(c_m0_m1_m2_n_thread_desc),
|
||||
decltype(c_m0_m1_m2_n_global_desc),
|
||||
Sequence<M0, 1, M2, 1>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{c_m0_m1_m2_n_global_desc,
|
||||
make_multi_index(m_thread_data_on_global / (M2 * M1),
|
||||
m_thread_data_on_global % (M2 * M1) / M2,
|
||||
m_thread_data_on_global % M2,
|
||||
n_thread_data_on_global)}
|
||||
.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_blk_buf_,
|
||||
c_m0_m1_m2_n_global_desc,
|
||||
c_global_buf,
|
||||
c_m0_m1_m2_n_global_tensor_iterator_hacks);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
const AGlobalDesc& a_k0_m_k1_global_desc,
|
||||
const BGlobalDesc& b_k0_n_k1_global_desc,
|
||||
const CGlobalDesc& c_m0_m1_m2_n_global_desc,
|
||||
const CBlockClusterDesc& c_block_cluster_desc,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>)
|
||||
{
|
||||
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
Run(p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k0_m_k1_global_desc,
|
||||
b_k0_n_k1_global_desc,
|
||||
c_m0_m1_m2_n_global_desc,
|
||||
c_block_cluster_desc,
|
||||
p_shared_block,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,498 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2_HPP
|
||||
#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_multi_index_transform_helper.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_xdlops.hpp"
|
||||
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
typename CBlockClusterDesc>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_dynamic_gemm_xdlops_v2(const FloatA* __restrict__ p_a_global,
|
||||
const FloatB* __restrict__ p_b_global,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
const AGlobalDesc a_k0_m_k1_global_desc,
|
||||
const BGlobalDesc b_k0_n_k1_global_desc,
|
||||
const CGlobalDesc c_m0_m1_m2_n_global_desc,
|
||||
const CBlockClusterDesc c_block_cluster_desc)
|
||||
{
|
||||
GridwiseGemm::Run(p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k0_m_k1_global_desc,
|
||||
b_k0_n_k1_global_desc,
|
||||
c_m0_m1_m2_n_global_desc,
|
||||
c_block_cluster_desc);
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
// pass tensor descriptor by __CONSTANT__ void pointer
|
||||
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
|
||||
// non-modifiable parameter address space, so compiler can enable corresponding optimization
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
typename CBlockClusterDesc>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_dynamic_gemm_xdlops_v2(const FloatA* __restrict__ p_a_global,
|
||||
const FloatB* __restrict__ p_b_global,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
const void __CONSTANT__* p_a_k0_m_k1_global_desc,
|
||||
const void __CONSTANT__* p_b_k0_n_k1_global_desc,
|
||||
const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc,
|
||||
const void __CONSTANT__* p_c_block_cluster_desc)
|
||||
{
|
||||
// first cast void __CONSTANT__ void* to void*
|
||||
// second cast void* to Desc*
|
||||
// the copy constructor of tensor descriptor doesn't take address_space(4)
|
||||
const auto a_k0_m_k1_global_desc =
|
||||
*reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k0_m_k1_global_desc);
|
||||
const auto b_k0_n_k1_global_desc =
|
||||
*reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k0_n_k1_global_desc);
|
||||
const auto c_m0_m1_m2_n_global_desc =
|
||||
*reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_m2_n_global_desc);
|
||||
|
||||
const auto c_block_cluster_desc =
|
||||
*reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
|
||||
|
||||
GridwiseGemm::Run(p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k0_m_k1_global_desc,
|
||||
b_k0_n_k1_global_desc,
|
||||
c_m0_m1_m2_n_global_desc,
|
||||
c_block_cluster_desc,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
#endif
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
typename CBlockClusterDesc,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPack,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadSliceLengths_K_M_KPack,
|
||||
typename ABlockTransferThreadClusterLengths_K_M_KPack,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_KPack,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K_N_KPack,
|
||||
typename BBlockTransferThreadClusterLengths_K_N_KPack,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_KPack,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGlobalIteratorHacks,
|
||||
typename BGlobalIteratorHacks,
|
||||
typename CGlobalIteratorHacks,
|
||||
typename AGlobalMoveSliceWindowIteratorHacks,
|
||||
typename BGlobalMoveSliceWindowIteratorHacks>
|
||||
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2
|
||||
{
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto max_lds_align = Number<KPack>{};
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size =
|
||||
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
const AGlobalDesc& a_k0_m_k1_global_desc,
|
||||
const BGlobalDesc& b_k0_n_k1_global_desc,
|
||||
const CGlobalDesc& c_m0_m1_m2_n_global_desc,
|
||||
const CBlockClusterDesc& c_block_cluster_desc,
|
||||
FloatAB* __restrict__ p_shared_block)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_a_global, a_k0_m_k1_global_desc.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_b_global, b_k0_n_k1_global_desc.GetElementSpaceSize());
|
||||
auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize());
|
||||
|
||||
const auto K0 = a_k0_m_k1_global_desc.GetLength(I0);
|
||||
const auto M = a_k0_m_k1_global_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_global_desc.GetLength(I1);
|
||||
const auto K1 = b_k0_n_k1_global_desc.GetLength(I2);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_work_idx =
|
||||
c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_global into SGPR
|
||||
const index_t m_block_data_idx_on_global =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_global =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = Number<KPack>{};
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, Number<KPack>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, Number<KPack>{}), max_lds_align);
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperation::Set,
|
||||
Sequence<KPerBlock, MPerBlock, KPack>,
|
||||
ABlockTransferThreadSliceLengths_K_M_KPack,
|
||||
ABlockTransferThreadClusterLengths_K_M_KPack,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_k0_m_k1_global_desc),
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_KPack,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
a_k0_m_k1_global_desc,
|
||||
make_multi_index(0, m_block_data_idx_on_global, 0),
|
||||
a_k0_m_k1_block_desc,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperation::Set,
|
||||
Sequence<KPerBlock, NPerBlock, KPack>,
|
||||
BBlockTransferThreadSliceLengths_K_N_KPack,
|
||||
BBlockTransferThreadClusterLengths_K_N_KPack,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_k0_n_k1_global_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_KPack,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
b_k0_n_k1_global_desc,
|
||||
make_multi_index(0, n_block_data_idx_on_global, 0),
|
||||
b_k0_n_k1_block_desc,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[KPerBlock, NPerBlock] is in LDS
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
|
||||
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
|
||||
NPerBlock % (NPerWave * NRepeat) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
|
||||
a_k0_m_k1_block_desc,
|
||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
|
||||
make_pass_through_transform(Number<KPack>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
|
||||
b_k0_n_k1_block_desc,
|
||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
|
||||
make_pass_through_transform(Number<KPack>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
|
||||
FloatAB,
|
||||
decltype(a_k0_m0_m1_k1_block_desc),
|
||||
decltype(b_k0_n0_n1_k1_block_desc),
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPack>{};
|
||||
|
||||
constexpr auto CLayout = blockwise_gemm.GetCLayout();
|
||||
|
||||
constexpr index_t BlkSize = CLayout.GetBlkSize();
|
||||
constexpr index_t NumBlks = CLayout.GetNumBlks();
|
||||
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
|
||||
|
||||
constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<NumXdlops>{}));
|
||||
|
||||
constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<NumBlks>{}, Number<BlkSize>{}));
|
||||
|
||||
StaticBuffer<AddressSpace::Vgpr,
|
||||
vector_type<FloatAcc, c_blk_nb_bs_desc.GetElementSpaceSize()>,
|
||||
c_mr_nr_nx_desc.GetElementSpaceSize()>
|
||||
c_thread_buf;
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size =
|
||||
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block = p_shared_block;
|
||||
FloatAB* p_b_block = p_shared_block + a_block_space_size;
|
||||
|
||||
// register allocation for output
|
||||
// auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
|
||||
// c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
|
||||
|
||||
// ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
|
||||
// decltype(c_m0_m1_n0_n1_thread_desc),
|
||||
// Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
|
||||
//.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||
constexpr auto a_k0_m_k1_global_iterator_hacks = AGlobalIteratorHacks{};
|
||||
constexpr auto b_k0_n_k1_global_iterator_hacks = BGlobalIteratorHacks{};
|
||||
|
||||
// hack to control index calculation when move slice window for A and B matrix for
|
||||
// threadwise copy
|
||||
constexpr auto a_k0_m_k1_global_move_slice_window_iterator_hack =
|
||||
AGlobalMoveSliceWindowIteratorHacks{};
|
||||
constexpr auto b_k0_n_k1_global_move_slice_window_iterator_hack =
|
||||
BGlobalMoveSliceWindowIteratorHacks{};
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
// preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||
}
|
||||
|
||||
// main body
|
||||
index_t k_block_data_begin = 0;
|
||||
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k0_m_k1_global_move_slice_window_iterator_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k0_n_k1_global_move_slice_window_iterator_hack);
|
||||
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||
|
||||
k_block_data_begin += KPerBlock;
|
||||
} while(k_block_data_begin < (K0 - KPerBlock));
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
|
||||
constexpr index_t M0 = CLayout.M1();
|
||||
constexpr index_t M1 = CLayout.N1();
|
||||
constexpr index_t M2 = CLayout.M0();
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_thread_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
|
||||
|
||||
StaticBuffer<AddressSpace::Vgpr, FloatC, BlkSize> c_blk_buf_;
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
|
||||
static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) {
|
||||
static_for<0, NumBlks, 1>{}([&](auto blk_i) {
|
||||
auto c_blk = c_thread_buf[Number<c_mr_nr_nx_desc.CalculateOffset(
|
||||
make_tuple(mr_i, nr_i, xdlops_i))>{}];
|
||||
|
||||
static_for<0, BlkSize, 1>{}([&](auto j) {
|
||||
c_blk_buf_(j) = c_blk.template AsType<FloatAcc>()[Number<
|
||||
c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}];
|
||||
});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.CalculateCThreadOriginDataIndex(
|
||||
mr_i, nr_i, xdlops_i, blk_i);
|
||||
|
||||
const index_t m_thread_data_on_global =
|
||||
m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
|
||||
|
||||
const index_t n_thread_data_on_global =
|
||||
n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks =
|
||||
CGlobalIteratorHacks{};
|
||||
|
||||
ThreadwiseDynamicTensorSliceTransfer_v1r3<
|
||||
FloatC,
|
||||
FloatC,
|
||||
decltype(c_m0_m1_m2_n_thread_desc),
|
||||
decltype(c_m0_m1_m2_n_global_desc),
|
||||
Sequence<M0, 1, M2, 1>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{c_m0_m1_m2_n_global_desc,
|
||||
make_multi_index(m_thread_data_on_global / (M2 * M1),
|
||||
m_thread_data_on_global % (M2 * M1) / M2,
|
||||
m_thread_data_on_global % M2,
|
||||
n_thread_data_on_global)}
|
||||
.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_blk_buf_,
|
||||
c_m0_m1_m2_n_global_desc,
|
||||
c_global_buf,
|
||||
c_m0_m1_m2_n_global_tensor_iterator_hacks);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
const AGlobalDesc& a_k0_m_k1_global_desc,
|
||||
const BGlobalDesc& b_k0_n_k1_global_desc,
|
||||
const CGlobalDesc& c_m0_m1_m2_n_global_desc,
|
||||
const CBlockClusterDesc& c_block_cluster_desc)
|
||||
{
|
||||
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
Run(p_a_global,
|
||||
p_b_global,
|
||||
p_c_global,
|
||||
a_k0_m_k1_global_desc,
|
||||
b_k0_n_k1_global_desc,
|
||||
c_m0_m1_m2_n_global_desc,
|
||||
c_block_cluster_desc,
|
||||
p_shared_block);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,509 +0,0 @@
|
||||
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R2_HPP
|
||||
#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_multi_index_transform_helper.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_xdlops.hpp"
|
||||
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CM0M1M2NGridDesc,
|
||||
typename CBlockClusterAdaptor>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_dynamic_gemm_xdlops_v2r2(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
|
||||
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperation CGlobalMemoryDataOperation,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CMNGridDesc,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridIteratorHacks,
|
||||
typename BGridIteratorHacks,
|
||||
typename CGridIteratorHacks,
|
||||
typename AGridMoveSliceWindowIteratorHacks,
|
||||
typename BGridMoveSliceWindowIteratorHacks>
|
||||
struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size =
|
||||
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
// TODO: turn on this
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
|
||||
"wrong! K1 need to be known at compile-time");
|
||||
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
|
||||
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
|
||||
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
|
||||
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
|
||||
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) &&
|
||||
(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1.value>{};
|
||||
|
||||
constexpr auto CLayout = xdlops_gemm.GetCLayout();
|
||||
|
||||
constexpr auto M0 = Number<CLayout.M1()>{};
|
||||
constexpr auto M1 = Number<CLayout.N1()>{};
|
||||
constexpr auto M2 = Number<CLayout.M0()>{};
|
||||
|
||||
const auto c_m0_m1_m2_n_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
c_m_n_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(M / (M1 * M2), M1, M2)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
|
||||
|
||||
return c_m0_m1_m2_n_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
|
||||
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
|
||||
const CBlockClusterAdaptor& c_block_cluster_adaptor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
|
||||
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
|
||||
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_work_idx =
|
||||
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperation::Set,
|
||||
Sequence<KPerBlock, MPerBlock, K1.value>,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
a_k0_m_k1_grid_desc,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_k0_m_k1_block_desc,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperation::Set,
|
||||
Sequence<KPerBlock, NPerBlock, K1.value>,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(
|
||||
b_k0_n_k1_grid_desc,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_k0_n_k1_block_desc,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[KPerBlock, NPerBlock] is in LDS
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
|
||||
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
|
||||
NPerBlock % (NPerWave * NRepeat) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
|
||||
a_k0_m_k1_block_desc,
|
||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
|
||||
b_k0_n_k1_block_desc,
|
||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
|
||||
FloatAB,
|
||||
decltype(a_k0_m0_m1_k1_block_desc),
|
||||
decltype(b_k0_n0_n1_k1_block_desc),
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1.value>{};
|
||||
|
||||
constexpr auto CLayout = blockwise_gemm.GetCLayout();
|
||||
|
||||
constexpr index_t BlkSize = CLayout.GetBlkSize();
|
||||
constexpr index_t NumBlks = CLayout.GetNumBlks();
|
||||
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
|
||||
|
||||
constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<NumXdlops>{}));
|
||||
|
||||
constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<NumBlks>{}, Number<BlkSize>{}));
|
||||
|
||||
StaticBuffer<AddressSpace::Vgpr,
|
||||
vector_type<FloatAcc, c_blk_nb_bs_desc.GetElementSpaceSize()>,
|
||||
c_mr_nr_nx_desc.GetElementSpaceSize()>
|
||||
c_thread_buf;
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size =
|
||||
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block = p_shared_block;
|
||||
FloatAB* p_b_block = p_shared_block + a_block_space_size;
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||
constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{};
|
||||
constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{};
|
||||
|
||||
// hack to control index calculation when move slice window for A and B matrix for
|
||||
// threadwise copy
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack =
|
||||
AGridMoveSliceWindowIteratorHacks{};
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack =
|
||||
BGridMoveSliceWindowIteratorHacks{};
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
|
||||
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
// preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||
}
|
||||
|
||||
// main body
|
||||
index_t k_block_data_begin = 0;
|
||||
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k0_m_k1_grid_move_slice_window_iterator_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k0_n_k1_grid_move_slice_window_iterator_hack);
|
||||
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||
|
||||
k_block_data_begin += KPerBlock;
|
||||
} while(k_block_data_begin < (K0 - KPerBlock));
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr index_t M0 = CLayout.M1();
|
||||
constexpr index_t M1 = CLayout.N1();
|
||||
constexpr index_t M2 = CLayout.M0();
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_thread_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
|
||||
|
||||
StaticBuffer<AddressSpace::Vgpr, FloatC, BlkSize> c_blk_buf_;
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
|
||||
static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) {
|
||||
static_for<0, NumBlks, 1>{}([&](auto blk_i) {
|
||||
auto c_blk = c_thread_buf[Number<c_mr_nr_nx_desc.CalculateOffset(
|
||||
make_tuple(mr_i, nr_i, xdlops_i))>{}];
|
||||
|
||||
static_for<0, BlkSize, 1>{}([&](auto j) {
|
||||
c_blk_buf_(j) = c_blk.template AsType<FloatAcc>()[Number<
|
||||
c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}];
|
||||
});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.CalculateCThreadOriginDataIndex(
|
||||
mr_i, nr_i, xdlops_i, blk_i);
|
||||
|
||||
const index_t m_thread_data_on_grid =
|
||||
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
|
||||
|
||||
const index_t n_thread_data_on_grid =
|
||||
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks =
|
||||
CGridIteratorHacks{};
|
||||
|
||||
ThreadwiseDynamicTensorSliceTransfer_v1r3<
|
||||
FloatC,
|
||||
FloatC,
|
||||
decltype(c_m0_m1_m2_n_thread_desc),
|
||||
decltype(c_m0_m1_m2_n_grid_desc),
|
||||
Sequence<M0, 1, M2, 1>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{c_m0_m1_m2_n_grid_desc,
|
||||
make_multi_index(m_thread_data_on_grid / (M2 * M1),
|
||||
m_thread_data_on_grid % (M2 * M1) / M2,
|
||||
m_thread_data_on_grid % M2,
|
||||
n_thread_data_on_grid)}
|
||||
.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_blk_buf_,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -12,6 +12,7 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
@@ -45,6 +46,50 @@ __global__ void
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CM0M1M2NGridDesc,
|
||||
typename CBlockClusterAdaptor>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void __CONSTANT__* p_a_k0_m_k1_grid_desc,
|
||||
const void __CONSTANT__* p_b_k0_n_k1_grid_desc,
|
||||
const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc,
|
||||
const void __CONSTANT__* p_c_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
|
||||
const auto c_m0_m1_m2_n_grid_desc =
|
||||
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
|
||||
const auto c_block_cluster_adaptor =
|
||||
*reinterpret_cast<const CBlockClusterAdaptor*>((const void*)p_c_block_cluster_adaptor);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
@@ -59,6 +104,7 @@ template <index_t BlockSize,
|
||||
index_t KPerBlock,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t K1Value,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
@@ -94,7 +140,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
@@ -160,7 +206,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1.value>{};
|
||||
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
|
||||
|
||||
constexpr auto CLayout = xdlops_gemm.GetCLayout();
|
||||
|
||||
@@ -267,7 +313,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperation::Set,
|
||||
Sequence<KPerBlock, MPerBlock, K1.value>,
|
||||
Sequence<KPerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
@@ -294,7 +340,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperation::Set,
|
||||
Sequence<KPerBlock, NPerBlock, K1.value>,
|
||||
Sequence<KPerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
@@ -354,7 +400,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
decltype(b_k0_n0_n1_k1_block_desc),
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1.value>{};
|
||||
K1>{};
|
||||
|
||||
constexpr auto CLayout = blockwise_gemm.GetCLayout();
|
||||
|
||||
|
||||
@@ -39,7 +39,6 @@
|
||||
|
||||
#if CK_USE_AMD_XDLOPS
|
||||
#include "amd_xdlops.hpp"
|
||||
#include "amd_xdlops_inline_asm.hpp"
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,359 @@
|
||||
#include "common_header.hpp"
|
||||
#include "type_helper.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type;
|
||||
using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type;
|
||||
using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_CONV_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
||||
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
||||
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
||||
|
||||
constexpr index_t MPerWave = CK_PARAM_MPerWave;
|
||||
constexpr index_t NPerWave = CK_PARAM_NPerWave;
|
||||
constexpr index_t MRepeat = CK_PARAM_MRepeat;
|
||||
constexpr index_t NRepeat = CK_PARAM_NRepeat;
|
||||
constexpr index_t K1 = CK_PARAM_K1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
||||
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
|
||||
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
|
||||
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
||||
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
|
||||
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
|
||||
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
||||
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||
|
||||
extern "C" __global__ void
|
||||
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare(
|
||||
int n,
|
||||
int c,
|
||||
int hi,
|
||||
int wi,
|
||||
int k,
|
||||
int y,
|
||||
int x,
|
||||
int convStrideH,
|
||||
int convStrideW,
|
||||
int convDilationY,
|
||||
int convDilationX,
|
||||
int leftPadH,
|
||||
int leftPadW,
|
||||
int rightPadH,
|
||||
int rightPadW,
|
||||
void* p_a_k0_m_k1_grid_desc,
|
||||
void* p_b_k0_n_k1_grid_desc,
|
||||
void* p_c_m0_m1_m2_n_grid_desc,
|
||||
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
||||
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
||||
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi));
|
||||
const auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x));
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo));
|
||||
|
||||
const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(convStrideH, convStrideW),
|
||||
make_tuple(convDilationY, convDilationX),
|
||||
make_tuple(leftPadH, leftPadW),
|
||||
make_tuple(rightPadH, rightPadW),
|
||||
Number<K1>{});
|
||||
|
||||
const auto a_k0_m_k1_grid_desc = descs[I0];
|
||||
const auto b_k0_n_k1_grid_desc = descs[I1];
|
||||
const auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using AGridIteratorHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridIteratorHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperation::Set,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks,
|
||||
false>;
|
||||
|
||||
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
|
||||
auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
if(hipThreadIdx_x == 0)
|
||||
{
|
||||
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
|
||||
a_k0_m_k1_grid_desc;
|
||||
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
|
||||
b_k0_n_k1_grid_desc;
|
||||
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
|
||||
c_m0_m1_m2_n_grid_desc;
|
||||
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
|
||||
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
};
|
||||
|
||||
extern "C" __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void __CONSTANT__* p_a_k0_m_k1_grid_desc,
|
||||
const void __CONSTANT__* p_b_k0_n_k1_grid_desc,
|
||||
const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc,
|
||||
const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
|
||||
constexpr auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3));
|
||||
constexpr auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
|
||||
|
||||
constexpr auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
Number<K1>{});
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
|
||||
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
|
||||
constexpr auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AGridIteratorHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridIteratorHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
|
||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperation::Set,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks,
|
||||
false>;
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
||||
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
|
||||
using CBlockIdToM0N0BlockClusterAdaptor =
|
||||
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
|
||||
const auto c_m0_m1_m2_n_grid_desc =
|
||||
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
};
|
||||
@@ -0,0 +1,359 @@
|
||||
#include "common_header.hpp"
|
||||
#include "type_helper.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type;
|
||||
using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type;
|
||||
using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_CONV_COMPTYPE)>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
||||
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
||||
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
||||
|
||||
constexpr index_t MPerWave = CK_PARAM_MPerWave;
|
||||
constexpr index_t NPerWave = CK_PARAM_NPerWave;
|
||||
constexpr index_t MRepeat = CK_PARAM_MRepeat;
|
||||
constexpr index_t NRepeat = CK_PARAM_NRepeat;
|
||||
constexpr index_t K1 = CK_PARAM_K1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
||||
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
|
||||
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
|
||||
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
||||
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
|
||||
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
|
||||
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
||||
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||
|
||||
extern "C" __global__ void
|
||||
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
|
||||
int n,
|
||||
int hi,
|
||||
int wi,
|
||||
int c,
|
||||
int k,
|
||||
int y,
|
||||
int x,
|
||||
int convStrideH,
|
||||
int convStrideW,
|
||||
int convDilationY,
|
||||
int convDilationX,
|
||||
int leftPadH,
|
||||
int leftPadW,
|
||||
int rightPadH,
|
||||
int rightPadW,
|
||||
void* p_a_k0_m_k1_grid_desc,
|
||||
void* p_b_k0_n_k1_grid_desc,
|
||||
void* p_c_m0_m1_m2_n_grid_desc,
|
||||
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
||||
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
||||
|
||||
const auto in_n_hi_wi_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, hi, wi, c));
|
||||
const auto wei_k_y_x_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, y, x, c));
|
||||
const auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, ho, wo, k));
|
||||
|
||||
const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
||||
in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
make_tuple(convStrideH, convStrideW),
|
||||
make_tuple(convDilationY, convDilationX),
|
||||
make_tuple(leftPadH, leftPadW),
|
||||
make_tuple(rightPadH, rightPadW),
|
||||
Number<K1>{});
|
||||
|
||||
const auto a_k0_m_k1_grid_desc = descs[I0];
|
||||
const auto b_k0_n_k1_grid_desc = descs[I1];
|
||||
const auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using BGridIteratorHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using AGridIteratorHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperation::Set,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks,
|
||||
false>;
|
||||
|
||||
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
|
||||
auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
if(hipThreadIdx_x == 0)
|
||||
{
|
||||
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
|
||||
a_k0_m_k1_grid_desc;
|
||||
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
|
||||
b_k0_n_k1_grid_desc;
|
||||
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
|
||||
c_m0_m1_m2_n_grid_desc;
|
||||
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
|
||||
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
};
|
||||
|
||||
extern "C" __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void __CONSTANT__* p_a_k0_m_k1_grid_desc,
|
||||
const void __CONSTANT__* p_b_k0_n_k1_grid_desc,
|
||||
const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc,
|
||||
const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_n_hi_wi_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 28, 28, 256));
|
||||
constexpr auto wei_k_y_x_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 3, 3, 256));
|
||||
constexpr auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 28, 28, 256));
|
||||
|
||||
constexpr auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
Number<K1>{});
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
|
||||
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
|
||||
constexpr auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using BGridIteratorHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using AGridIteratorHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperation::Set,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridIteratorHacks,
|
||||
BGridIteratorHacks,
|
||||
CGridIteratorHacks,
|
||||
AGridMoveSliceWindowIteratorHacks,
|
||||
BGridMoveSliceWindowIteratorHacks,
|
||||
false>;
|
||||
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
||||
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
|
||||
using CBlockIdToM0N0BlockClusterAdaptor =
|
||||
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
|
||||
const auto c_m0_m1_m2_n_grid_desc =
|
||||
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
};
|
||||
Reference in New Issue
Block a user