mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:12:19 +00:00
fix clang format
This commit is contained in:
@@ -60,11 +60,11 @@ TEST_CASE(test_problem_kernel)
|
||||
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
|
||||
auto&& solution = solutions[i];
|
||||
auto src = ck::host::InterpolateString(gemm_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
rtc::compile_options options;
|
||||
|
||||
223
example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc
Executable file → Normal file
223
example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc
Executable file → Normal file
@@ -127,44 +127,47 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
case 0: break;
|
||||
case 1:
|
||||
|
||||
a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
|
||||
a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
|
||||
default:
|
||||
a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
default:
|
||||
a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
|
||||
a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
|
||||
break;
|
||||
break;
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_re(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_re(sizeof(EDataType) *
|
||||
e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
|
||||
|
||||
DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_img(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_img(sizeof(EDataType) *
|
||||
e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
|
||||
|
||||
// Intermediate Value For E Real and Img
|
||||
DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
|
||||
|
||||
DeviceMem e_device_buf_re1(sizeof(EDataType) *
|
||||
e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_img1(sizeof(EDataType) *
|
||||
e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf_re.ToDevice(a_ms_ks_re.mData.data());
|
||||
b_device_buf_re.ToDevice(b_ns_ks_re.mData.data());
|
||||
@@ -181,7 +184,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
// set zero for intermediate values
|
||||
e_device_buf_re1.SetZero();
|
||||
e_device_buf_img1.SetZero();
|
||||
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{alpha, beta};
|
||||
@@ -189,23 +192,24 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
// device operation
|
||||
// For real Intermediate Value re_1
|
||||
|
||||
auto op = DeviceOpInstance{};
|
||||
auto invoker = op.MakeInvoker();
|
||||
auto argument_re1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
|
||||
b_device_buf_re.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
|
||||
e_device_buf_re1.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
auto op = DeviceOpInstance{};
|
||||
auto invoker = op.MakeInvoker();
|
||||
auto argument_re1 =
|
||||
op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
|
||||
b_device_buf_re.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
|
||||
e_device_buf_re1.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!op.IsSupportedArgument(argument_re1))
|
||||
{
|
||||
@@ -216,7 +220,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
|
||||
float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
|
||||
alpha = -1.f;
|
||||
beta = 1.f;
|
||||
|
||||
@@ -228,21 +231,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
// For real Intermediate Value re_2
|
||||
// auto op = DeviceOpInstance{};
|
||||
// auto invoker = op.MakeInvoker();
|
||||
auto argument_re2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
|
||||
b_device_buf_img.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
|
||||
e_device_buf_re.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
auto argument_re2 =
|
||||
op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
|
||||
b_device_buf_img.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
|
||||
e_device_buf_re.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!op.IsSupportedArgument(argument_re2))
|
||||
{
|
||||
@@ -253,7 +257,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
|
||||
float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
|
||||
alpha = 1.f;
|
||||
beta = 1.f;
|
||||
|
||||
@@ -261,22 +264,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
b_element_op = BElementOp{};
|
||||
cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
auto argument_img1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
|
||||
b_device_buf_img.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()},
|
||||
e_device_buf_img1.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
auto argument_img1 =
|
||||
op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
|
||||
b_device_buf_img.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()},
|
||||
e_device_buf_img1.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!op.IsSupportedArgument(argument_img1))
|
||||
{
|
||||
@@ -290,23 +293,22 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
alpha = 1.f;
|
||||
beta = 1.f;
|
||||
|
||||
auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
|
||||
b_device_buf_re.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()},
|
||||
e_device_buf_img.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
|
||||
auto argument_img2 =
|
||||
op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
|
||||
b_device_buf_re.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()},
|
||||
e_device_buf_img.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!op.IsSupportedArgument(argument_img2))
|
||||
{
|
||||
@@ -317,7 +319,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
|
||||
float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
|
||||
ck::index_t M =
|
||||
ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
|
||||
|
||||
@@ -331,9 +332,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2;
|
||||
|
||||
float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1 ;
|
||||
float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
@@ -343,7 +344,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data());
|
||||
|
||||
auto isRealOk = 0;
|
||||
auto isImgOk = 0;
|
||||
auto isImgOk = 0;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
@@ -366,17 +367,16 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
auto ref_op = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_op.MakeInvoker();
|
||||
|
||||
auto ref_argument_re =
|
||||
ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);
|
||||
auto ref_argument_re = ref_op.MakeArgument(
|
||||
a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument_re);
|
||||
|
||||
alpha = 1.f;
|
||||
beta = 1.f;
|
||||
|
||||
|
||||
cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
|
||||
for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0)
|
||||
{
|
||||
for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1)
|
||||
@@ -395,11 +395,11 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
|
||||
alpha = 1.f;
|
||||
beta = -1.f;
|
||||
|
||||
|
||||
cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
auto ref_argument_re1 =
|
||||
ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op);
|
||||
auto ref_argument_re1 = ref_op.MakeArgument(
|
||||
a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument_re1);
|
||||
|
||||
@@ -419,23 +419,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
|
||||
|
||||
|
||||
|
||||
isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
|
||||
|
||||
// Img Part Verification
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
auto ref_argument_img =
|
||||
ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op);
|
||||
|
||||
auto ref_argument_img = ref_op.MakeArgument(
|
||||
a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument_img);
|
||||
|
||||
alpha = 1.f;
|
||||
beta = 1.f;
|
||||
|
||||
|
||||
cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
|
||||
@@ -454,9 +451,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
auto ref_argument_img1 =
|
||||
ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op);
|
||||
|
||||
auto ref_argument_img1 = ref_op.MakeArgument(
|
||||
a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument_img1);
|
||||
|
||||
for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
|
||||
@@ -475,7 +472,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
|
||||
isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
|
||||
|
||||
return (isRealOk && isImgOk);
|
||||
}
|
||||
|
||||
@@ -42,27 +42,27 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatAB* __restrict__ p_b1_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const AccElementwiseOperation acc_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
const FloatAB* __restrict__ p_b1_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const AccElementwiseOperation acc_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
|
||||
@@ -37,22 +37,22 @@ template <typename GridwiseGemm,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2ETileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
|
||||
@@ -124,9 +124,8 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
|
||||
return *reinterpret_cast<int64_t*>(to_obj);
|
||||
}
|
||||
|
||||
template <
|
||||
typename Object,
|
||||
typename = ck::enable_if_t<ck::is_class_v<Object> && ck::is_trivially_copyable_v<Object>>>
|
||||
template <typename Object,
|
||||
typename = ck::enable_if_t<ck::is_class_v<Object> && ck::is_trivially_copyable_v<Object>>>
|
||||
__device__ auto amd_wave_read_first_lane(const Object& obj)
|
||||
{
|
||||
using Size = unsigned;
|
||||
|
||||
@@ -43,15 +43,15 @@ template <typename T,
|
||||
ck::enable_if_t<!(ck::is_same<float, T>{} || ck::is_same<half_t, T>{}), bool> = false>
|
||||
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
|
||||
{
|
||||
#ifdef __HIPCC_RTC__
|
||||
#ifdef __HIPCC_RTC__
|
||||
static_cast<void>(id);
|
||||
static_cast<void>(val);
|
||||
static_cast<void>(seed);
|
||||
#else
|
||||
#else
|
||||
std::ignore = id;
|
||||
std::ignore = val;
|
||||
std::ignore = seed;
|
||||
#endif
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user