mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
tidy
This commit is contained in:
@@ -20,6 +20,12 @@ __device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p)
|
||||
return (T*)p;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p)
|
||||
{
|
||||
return (T CONSTANT*)p;
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
|
||||
@@ -11,59 +11,11 @@ namespace ck {
|
||||
template <typename T>
|
||||
__host__ __device__ void print_array(const char* s, T a)
|
||||
{
|
||||
using data_type = decltype(a.At(Number<0>{}));
|
||||
constexpr index_t nsize = a.Size();
|
||||
|
||||
#if 0
|
||||
if constexpr(is_same<data_type, uint32_t>{})
|
||||
{
|
||||
printf("%s size %u, {", s, nsize);
|
||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", uint32_t{a[i]}); });
|
||||
printf("}\n");
|
||||
}
|
||||
else if constexpr(is_same<data_type, int32_t>{})
|
||||
{
|
||||
printf("%s size %d, {", s, nsize);
|
||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
|
||||
printf("}\n");
|
||||
}
|
||||
else if constexpr(is_same<data_type, bool>{})
|
||||
{
|
||||
printf("%s size %d, {", s, nsize);
|
||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", bool{a[i]}); });
|
||||
printf("}\n");
|
||||
}
|
||||
#else
|
||||
printf("%s size %d, {", s, nsize);
|
||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
|
||||
printf("}\n");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void print_array_v2(const char* s, T a)
|
||||
{
|
||||
using data_type = decltype(a.At(Number<0>{}));
|
||||
constexpr index_t nsize = a.Size();
|
||||
|
||||
#if 0
|
||||
if constexpr(is_same<data_type, uint32_t>{})
|
||||
{
|
||||
printf("%s size %u, {", s, nsize);
|
||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%u] %u, ", i.value, a[i]); });
|
||||
printf("}\n");
|
||||
}
|
||||
else if constexpr(is_same<data_type, int32_t>{})
|
||||
{
|
||||
printf("%s size %d, {", s, nsize);
|
||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
|
||||
printf("}\n");
|
||||
}
|
||||
#else
|
||||
printf("%s size %d, {", s, nsize);
|
||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
|
||||
printf("}\n");
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -257,9 +257,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
|
||||
@@ -194,7 +194,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -221,7 +220,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -248,7 +246,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -275,7 +272,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
|
||||
@@ -244,7 +244,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
@@ -270,7 +269,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
@@ -296,7 +294,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
@@ -322,7 +319,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
|
||||
@@ -257,7 +257,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
@@ -284,7 +283,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
@@ -311,7 +309,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
@@ -338,7 +335,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
|
||||
@@ -189,7 +189,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -216,7 +215,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -243,7 +241,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -270,7 +267,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -315,14 +311,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
||||
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
@@ -343,14 +340,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
||||
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
@@ -371,14 +369,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
||||
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -399,14 +398,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
||||
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
|
||||
@@ -185,7 +185,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -212,7 +211,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -239,7 +237,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -266,7 +263,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -311,14 +307,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
||||
cast_pointer_to_constant_address_space(
|
||||
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
@@ -339,14 +338,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
||||
cast_pointer_to_constant_address_space(
|
||||
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
@@ -367,14 +369,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
||||
cast_pointer_to_constant_address_space(
|
||||
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -395,14 +400,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
||||
cast_pointer_to_constant_address_space(
|
||||
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
|
||||
@@ -153,7 +153,6 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
@@ -173,20 +172,19 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
|
||||
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
|
||||
|
||||
float ave_time =
|
||||
launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
(void CONSTANT*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(),
|
||||
(void CONSTANT*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
||||
float ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
#endif
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -142,10 +142,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
||||
|
||||
switch(layout)
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
{
|
||||
case ConvTensorLayout::NCHW:
|
||||
// NCHW
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(C);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||
@@ -158,9 +156,9 @@ int main(int argc, char* argv[])
|
||||
out_lengths_host[1] = static_cast<std::size_t>(K);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
||||
break;
|
||||
case ConvTensorLayout::NHWC:
|
||||
// NHWC
|
||||
}
|
||||
else if(layout == ConvTensorLayout::NHWC)
|
||||
{
|
||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
||||
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
||||
@@ -173,8 +171,10 @@ int main(int argc, char* argv[])
|
||||
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
||||
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
||||
out_lengths_host[3] = static_cast<std::size_t>(K);
|
||||
break;
|
||||
default: throw std::runtime_error("wrong! not implemented");
|
||||
}
|
||||
else
|
||||
{
|
||||
std::runtime_error("wrong! not implemented");
|
||||
}
|
||||
|
||||
Tensor<in_data_t> in(in_lengths_host);
|
||||
|
||||
@@ -34,24 +34,16 @@ struct KernelTimer
|
||||
using device_stream_t = hipStream_t;
|
||||
|
||||
template <typename... Args, typename F>
|
||||
void launch_kernel(F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
hipStream_t stream_id,
|
||||
Args... args)
|
||||
void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
|
||||
{
|
||||
hipStream_t stream_id = nullptr;
|
||||
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
||||
}
|
||||
|
||||
template <typename... Args, typename F>
|
||||
float launch_and_time_kernel(F kernel,
|
||||
int nrepeat,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
hipStream_t stream_id,
|
||||
Args... args)
|
||||
float launch_and_time_kernel(
|
||||
F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
|
||||
{
|
||||
KernelTimer timer;
|
||||
|
||||
@@ -66,6 +58,8 @@ float launch_and_time_kernel(F kernel,
|
||||
|
||||
printf("Warm up\n");
|
||||
|
||||
hipStream_t stream_id = nullptr;
|
||||
|
||||
// warm up
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user