mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
MIGraphX hipRTC fix (#1923)
* fixed hiprtc compilation issues from new additions, removed clashing mixed precision functionality from codegen(ignore the whole file) * fixed device op error: misplaced header guard * restrict virtual function use in device_gemm_multiple_d file for codegen hiprtc compilation * add CK_CODE_GEN_RTC flag for compilation, since this flag has wider coverage for hiprtc compilation * fixed conditional error in amd_ck_fp8.hpp * Add MaskOutUpperTriangle as a problem parameter to BatchedGemmSoftmaxGemm and disable tests with MaskOutUpperTriangle==True. Signed-off-by: Mirza Halilcevic <mirza.halilcevic@amd.com> --------- Signed-off-by: Mirza Halilcevic <mirza.halilcevic@amd.com> Co-authored-by: Mirza Halilcevic <mirza.halilcevic@amd.com>
This commit is contained in:
@@ -15,23 +15,24 @@ namespace device_batched_gemm_softmax_gemm {
|
||||
// defines the problem specification for a GEMM operation
|
||||
struct Problem
|
||||
{
|
||||
std::size_t M = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t K = 0;
|
||||
std::size_t O = 0;
|
||||
bool TransA = false;
|
||||
bool TransB = false;
|
||||
bool TransB1 = false;
|
||||
bool TransC = false;
|
||||
DataType ADataType = DataType::Half;
|
||||
DataType BDataType = DataType::Half;
|
||||
DataType B1DataType = DataType::Half;
|
||||
DataType CDataType = DataType::Half;
|
||||
std::string AElementOp = PassThrough;
|
||||
std::string BElementOp = PassThrough;
|
||||
std::string B1ElementOp = PassThrough;
|
||||
std::string CElementOp = PassThrough;
|
||||
std::string AccElementOp = Scale;
|
||||
std::size_t M = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t K = 0;
|
||||
std::size_t O = 0;
|
||||
bool TransA = false;
|
||||
bool TransB = false;
|
||||
bool TransB1 = false;
|
||||
bool TransC = false;
|
||||
DataType ADataType = DataType::Half;
|
||||
DataType BDataType = DataType::Half;
|
||||
DataType B1DataType = DataType::Half;
|
||||
DataType CDataType = DataType::Half;
|
||||
std::string AElementOp = PassThrough;
|
||||
std::string BElementOp = PassThrough;
|
||||
std::string B1ElementOp = PassThrough;
|
||||
std::string CElementOp = PassThrough;
|
||||
std::string AccElementOp = Scale;
|
||||
bool MaskOutUpperTriangle = false;
|
||||
|
||||
// returns the correct device op file for the operation
|
||||
std::string GetIncludeHeader() const;
|
||||
|
||||
Reference in New Issue
Block a user