diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp index aa0ab162fc..503c87e138 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -78,7 +78,7 @@ int main(int argc, char* argv[]) exit(0); } - int group_count = 4; + int group_count = rand() % 16 + 1; // GEMM shape std::vector gemm_shapes; @@ -189,12 +189,17 @@ int main(int argc, char* argv[]) auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; - // do GEMM auto gemm = DeviceGemmInstance{}; auto invoker = gemm.MakeInvoker(); + + // do GEMM auto argument = gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); + DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer()); + if(!gemm.IsSupportedArgument(argument)) { throw std::runtime_error( diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 9bc3cb1a02..1f6319d3f7 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -42,6 +42,8 @@ struct BaseOperator virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } + virtual void SetWorkSpacePointer(BaseArgument*, void*) const {} + virtual ~BaseOperator() {} }; diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp index 08a70823be..0617b4fcb7 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp @@ -24,57 +24,33 @@ template + bool HasMainKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_xdlops_v2r3( - const StaticallyIndexedArray gemm_descs, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op) + kernel_grouped_gemm_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); -#if 1 - static_for<0, MaxGroupCount, 1>{}([&](auto i) { - if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ && - i < group_count) - { - auto group_id = i; - - GridwiseGemm::template Run( - gemm_descs[group_id].a_ptr, - gemm_descs[group_id].b_ptr, - gemm_descs[group_id].c_ptr, - p_shared, - gemm_descs[group_id].a_grid_desc_k0_m_k1_, - gemm_descs[group_id].b_grid_desc_k0_n_k1_, - gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - a_element_op, - b_element_op, - c_element_op, - gemm_descs[group_id].grouped_gemm_block_2_ctile_map_); - } - }); -#else - const auto gemm_desc_ptr = reinterpret_cast(&gemm_descs); + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); index_t group_id = 0; - static_for<0, MaxGroupCount, 1>{}([&](auto i) { - group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd && - i < group_count) - ? i - : group_id; - }); - - const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart; + for(index_t i = 0; i < group_count; i++) + { + group_id = + (block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_) + ? i + : group_id; + } GridwiseGemm::template Run( gemm_desc_ptr[group_id].a_ptr, @@ -87,11 +63,9 @@ __global__ void a_element_op, b_element_op, c_element_op, - gemm_desc_ptr[group_id].block_2_ctile_map_, - block_id_grp); -#endif + gemm_desc_ptr[group_id].grouped_gemm_block_2_ctile_map_); #else - ignore = gemm_descs; + ignore = gemm_descs_const; ignore = group_count; ignore = a_element_op; ignore = b_element_op; @@ -388,6 +362,8 @@ struct DeviceGroupedGemmXdl { grid_size_ = 0; + gemm_descs_args_workspace_ = nullptr; + group_count_ = ck::type_convert(gemm_shapes.size()); if(!(group_count_ == ck::type_convert(p_a.size()) && @@ -461,6 +437,8 @@ struct DeviceGroupedGemmXdl std::vector gemm_desc_kernel_arg_; + void* gemm_descs_args_workspace_; + index_t grid_size_; }; @@ -471,49 +449,49 @@ struct DeviceGroupedGemmXdl float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - StaticallyIndexedArray gemm_desc_kernel_args; - bool has_main_k_block_loop = true; - static_for<0, MaxGroupCount, 1>{}([&](auto i) { - if(i < arg.gemm_desc_kernel_arg_.size()) + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{" + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"; + + std::cout << ", arg.b_grid_desc_k0_n_k1_{" + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"; + + std::cout << ", arg.c_grid_desc_m_n_{ " + << arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I1) << "}" + << std::endl; + + if(!GridwiseGemm::CheckValidity( + arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_, + arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_, + arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_, + arg.gemm_desc_kernel_arg_[i].grouped_gemm_block_2_ctile_map_)) { - gemm_desc_kernel_args(i) = arg.gemm_desc_kernel_arg_[i]; - - std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{" - << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", " - << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"; - - std::cout << ", arg.b_grid_desc_k0_n_k1_{" - << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", " - << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"; - - std::cout << ", arg.c_grid_desc_m_n_{ " - << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I0) << ", " - << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}" - << std::endl; - - if(!GridwiseGemm::CheckValidity( - gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_, - gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_, - gemm_desc_kernel_args[i].c_grid_desc_m_n_, - gemm_desc_kernel_args[i].grouped_gemm_block_2_ctile_map_)) - { - throw std::runtime_error( - "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); - } - - const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) * - gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2); - - if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop) - { - throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); - } + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); } - }); + + const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) * + arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + hipGetErrorString( + hipMemcpy(arg.gemm_descs_args_workspace_, + arg.gemm_desc_kernel_arg_.data(), + arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg), + hipMemcpyHostToDevice)); float ave_time = 0; @@ -523,23 +501,23 @@ struct DeviceGroupedGemmXdl kernel_grouped_gemm_xdlops_v2r3, + GemmDescKernelArg, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, - true, - MaxGroupCount>; + true>; - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - gemm_desc_kernel_args, - arg.gemm_desc_kernel_arg_.size(), - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_), + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); } else { @@ -547,23 +525,23 @@ struct DeviceGroupedGemmXdl kernel_grouped_gemm_xdlops_v2r3, + GemmDescKernelArg, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, - false, - MaxGroupCount>; + false>; - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - gemm_desc_kernel_args, - arg.gemm_desc_kernel_arg_.size(), - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_), + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); } return ave_time; @@ -652,6 +630,16 @@ struct DeviceGroupedGemmXdl return str.str(); } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * sizeof(GemmDescKernelArg); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override + { + dynamic_cast(p_arg)->gemm_descs_args_workspace_ = workspace_ptr; + } }; } // namespace device diff --git a/test/grouped_gemm/grouped_gemm_fp16.cpp b/test/grouped_gemm/grouped_gemm_fp16.cpp index a97133dca6..fc8ec66b51 100644 --- a/test/grouped_gemm/grouped_gemm_fp16.cpp +++ b/test/grouped_gemm/grouped_gemm_fp16.cpp @@ -141,10 +141,15 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) auto c_element_op = PassThrough{}; // do GEMM - auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer(); + auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer(); + auto argument_ptr = groupedGemmPtr->MakeArgumentPointer( p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); + DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get())); + + groupedGemmPtr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + invoker_ptr->Run(argument_ptr.get()); for(std::size_t i = 0; i < gemm_shapes.size(); i++)