diff --git a/example/20_cgemm/cgemm_xdl_bf16.cpp b/example/20_cgemm/cgemm_xdl_bf16.cpp index 836a3c13dc..807d2e35ed 100644 --- a/example/20_cgemm/cgemm_xdl_bf16.cpp +++ b/example/20_cgemm/cgemm_xdl_bf16.cpp @@ -151,6 +151,7 @@ int main(int argc, char* argv[]) Tensor c_m_n_real_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_imag_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor aux(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor aux_2(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl; std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl; @@ -159,6 +160,7 @@ int main(int argc, char* argv[]) std::cout << "c_m_n_real: " << c_m_n_real_device_result.mDesc << std::endl; std::cout << "c_m_n_imag: " << c_m_n_imag_device_result.mDesc << std::endl; std::cout << "aux: " << aux.mDesc << std::endl; + std::cout << "aux_2: " << aux_2.mDesc << std::endl; switch(init_method) { @@ -185,6 +187,7 @@ int main(int argc, char* argv[]) DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * c_m_n_imag_device_result.mDesc.GetElementSpace()); DeviceMem aux_device_buf(sizeof(CDataType) * aux.mDesc.GetElementSpace()); + DeviceMem aux_2_device_buf(sizeof(CDataType) * aux_2.mDesc.GetElementSpace()); a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data()); a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data()); @@ -206,6 +209,7 @@ int main(int argc, char* argv[]) static_cast(c_m_n_real_device_buf.GetDeviceBuffer()), static_cast(c_m_n_imag_device_buf.GetDeviceBuffer()), static_cast(aux_device_buf.GetDeviceBuffer()), + static_cast(aux_2_device_buf.GetDeviceBuffer()), M, N, K, diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp index 8ff8d2d432..5695fe628a 100644 --- a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp @@ -20,6 +20,7 @@ struct DeviceCGemm : public BaseOperator void* p_c_real, void* p_c_imag, void* p_aux, + void* p_aux_2, ck::index_t M, ck::index_t N, ck::index_t K, diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp index 1f6ebc7042..9ed8311315 100644 --- a/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp @@ -390,6 +390,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle CDataType* p_c_grid_real, CDataType* p_c_grid_imag, CDataType* p_aux_grid, + CDataType* p_aux_2_grid, index_t MRaw, index_t NRaw, index_t KRaw, @@ -406,6 +407,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle p_c_grid_real_{p_c_grid_real}, p_c_grid_imag_{p_c_grid_imag}, p_aux_grid_{p_aux_grid}, + p_aux_2_grid_{p_aux_2_grid}, a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, @@ -434,6 +436,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle CDataType* p_c_grid_real_; CDataType* p_c_grid_imag_; CDataType* p_aux_grid_; + CDataType* p_aux_2_grid_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; CGridDesc_M_N c_grid_desc_m_n_; @@ -488,7 +491,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle 0, arg.p_a_grid_real_, arg.p_b_grid_real_, - arg.p_c_grid_real_, + arg.p_aux_grid_, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_, @@ -505,7 +508,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle 0, arg.p_a_grid_imag_, arg.p_b_grid_imag_, - arg.p_aux_grid_, + arg.p_aux_2_grid_, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_, @@ -514,7 +517,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_ctile_map_); - // c_real = c_real - aux needed here!!! + // c_real = aux - aux_2 needed here!!! ave_time += launch_and_time_kernel(stream_config, @@ -524,7 +527,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle 0, arg.p_a_grid_real_, arg.p_b_grid_imag_, - arg.p_c_grid_imag_, + arg.p_aux_grid_, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_, @@ -541,7 +544,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle 0, arg.p_a_grid_imag_, arg.p_b_grid_real_, - arg.p_aux_grid_, + arg.p_aux_2_grid_, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_, @@ -550,7 +553,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_ctile_map_); - // c_imag = c_imag + aux needed here!!! + // c_imag = aux + aux_2 needed here!!! } else { @@ -575,7 +578,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle 0, arg.p_a_grid_real_, arg.p_b_grid_real_, - arg.p_c_grid_real_, + arg.p_aux_grid_, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_, @@ -592,7 +595,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle 0, arg.p_a_grid_imag_, arg.p_b_grid_imag_, - arg.p_aux_grid_, + arg.p_aux_2_grid_, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_, @@ -601,7 +604,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_ctile_map_); - // // c_real = c_real - aux needed here!!! + // // c_real = aux - aux_2 needed here!!! ave_time += launch_and_time_kernel(stream_config, @@ -611,7 +614,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle 0, arg.p_a_grid_real_, arg.p_b_grid_imag_, - arg.p_c_grid_imag_, + arg.p_aux_grid_, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_, @@ -628,7 +631,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle 0, arg.p_a_grid_imag_, arg.p_b_grid_real_, - arg.p_aux_grid_, + arg.p_aux_2_grid_, arg.a_element_op_, arg.b_element_op_, arg.c_element_op_, @@ -637,7 +640,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_ctile_map_); - // c_imag = c_imag + aux needed here!!! + // c_imag = aux + aux_2 needed here!!! } return ave_time; @@ -676,6 +679,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle CDataType* p_c_real, CDataType* p_c_imag, CDataType* p_aux, + CDataType* p_aux_2, index_t MRaw, index_t NRaw, index_t KRaw, @@ -693,6 +697,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle p_c_real, p_c_imag, p_aux, + p_aux_2, MRaw, NRaw, KRaw, @@ -714,6 +719,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle void* p_c_real, void* p_c_imag, void* p_aux, + void* p_aux_2, index_t MRaw, index_t NRaw, index_t KRaw, @@ -732,6 +738,7 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle static_cast(p_c_real), static_cast(p_c_imag), static_cast(p_aux), + static_cast(p_aux_2), MRaw, NRaw, KRaw, diff --git a/test/cgemm/cgemm_util.hpp b/test/cgemm/cgemm_util.hpp index f45405b275..1a7439e075 100644 --- a/test/cgemm/cgemm_util.hpp +++ b/test/cgemm/cgemm_util.hpp @@ -73,6 +73,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr, Tensor& C_real, Tensor& C_imag, Tensor& Aux, + Tensor& Aux_2, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) @@ -84,6 +85,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr, DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * C_real.mDesc.GetElementSpace()); DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * C_imag.mDesc.GetElementSpace()); DeviceMem aux_device_buf(sizeof(CDataType) * Aux.mDesc.GetElementSpace()); + DeviceMem aux_2_device_buf(sizeof(CDataType) * Aux_2.mDesc.GetElementSpace()); a_m_k_real_device_buf.ToDevice(A_real.mData.data()); a_m_k_imag_device_buf.ToDevice(A_imag.mData.data()); @@ -99,6 +101,7 @@ void RunDeviceCGEMM(DeviceCGemmPtr_& cgemmPtr, static_cast(c_m_n_real_device_buf.GetDeviceBuffer()), static_cast(c_m_n_imag_device_buf.GetDeviceBuffer()), static_cast(aux_device_buf.GetDeviceBuffer()), + static_cast(aux_2_device_buf.GetDeviceBuffer()), params.M, params.N, params.K, @@ -167,6 +170,8 @@ struct TestCGemm f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); Tensor aux( f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor aux_2( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); auto f_generate_tensor_value = [](auto& tensor, auto type) { using dataType = decltype(type); @@ -187,7 +192,8 @@ struct TestCGemm c_m_n_imag_host_result, c_m_n_real_device_result, c_m_n_imag_device_result, - aux); + aux, + aux_2); } auto operator()(DeviceCGemmPtr_& cgemmPtr) @@ -216,6 +222,7 @@ struct TestCGemm Tensor& c_device_real = std::get<6>(host_tensors); Tensor& c_device_imag = std::get<7>(host_tensors); Tensor& aux = std::get<8>(host_tensors); + Tensor& aux_2 = std::get<9>(host_tensors); auto a_element_op = AElementwiseOperation{}; auto b_element_op = BElementwiseOperation{}; @@ -248,6 +255,7 @@ struct TestCGemm c_device_real, c_device_imag, aux, + aux_2, a_element_op, b_element_op, c_element_op); @@ -319,6 +327,8 @@ struct TestCGemmBF16 f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); Tensor aux_bf16( f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor aux_2_bf16( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); Tensor a_m_k_real_fp32( f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); @@ -354,6 +364,7 @@ struct TestCGemmBF16 c_m_n_real_device_bf16, c_m_n_imag_device_bf16, aux_bf16, + aux_2_bf16, a_m_k_real_fp32, a_m_k_imag_fp32, b_k_n_real_fp32, @@ -383,14 +394,15 @@ struct TestCGemmBF16 Tensor& c_real_device_bf16 = std::get<4>(host_tensors); Tensor& c_imag_device_bf16 = std::get<5>(host_tensors); Tensor& aux_bf16 = std::get<6>(host_tensors); - Tensor& a_real_fp32 = std::get<7>(host_tensors); - Tensor& a_imag_fp32 = std::get<8>(host_tensors); - Tensor& b_real_fp32 = std::get<9>(host_tensors); - Tensor& b_imag_fp32 = std::get<10>(host_tensors); - Tensor& c_real_host_fp32 = std::get<11>(host_tensors); - Tensor& c_imag_host_fp32 = std::get<12>(host_tensors); - Tensor& c_real_device_fp32 = std::get<13>(host_tensors); - Tensor& c_imag_device_fp32 = std::get<14>(host_tensors); + Tensor& aux_2_bf16 = std::get<7>(host_tensors); + Tensor& a_real_fp32 = std::get<8>(host_tensors); + Tensor& a_imag_fp32 = std::get<9>(host_tensors); + Tensor& b_real_fp32 = std::get<10>(host_tensors); + Tensor& b_imag_fp32 = std::get<11>(host_tensors); + Tensor& c_real_host_fp32 = std::get<12>(host_tensors); + Tensor& c_imag_host_fp32 = std::get<13>(host_tensors); + Tensor& c_real_device_fp32 = std::get<14>(host_tensors); + Tensor& c_imag_device_fp32 = std::get<15>(host_tensors); auto a_element_op = AElementwiseOperation{}; auto b_element_op = BElementwiseOperation{}; @@ -424,6 +436,7 @@ struct TestCGemmBF16 c_real_device_bf16, c_imag_device_bf16, aux_bf16, + aux_2_bf16, a_element_op, b_element_op, c_element_op);