mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
MIGraphX hipRTC fix (#1923)
* fixed hiprtc compilation issues from new additions, removed clashing mixed precision functionality from codegen(ignore the whole file) * fixed device op error: misplaced header guard * restrict virtual function use in device_gemm_multiple_d file for codegen hiprtc compilation * add CK_CODE_GEN_RTC flag for compilation, since this flag has wider coverage for hiprtc compilation * fixed conditional error in amd_ck_fp8.hpp * Add MaskOutUpperTriangle as a problem parameter to BatchedGemmSoftmaxGemm and disable tests with MaskOutUpperTriangle==True. Signed-off-by: Mirza Halilcevic <mirza.halilcevic@amd.com> --------- Signed-off-by: Mirza Halilcevic <mirza.halilcevic@amd.com> Co-authored-by: Mirza Halilcevic <mirza.halilcevic@amd.com>
This commit is contained in:
@@ -15,23 +15,24 @@ namespace device_batched_gemm_softmax_gemm {
|
||||
// defines the problem specification for a GEMM operation
|
||||
struct Problem
|
||||
{
|
||||
std::size_t M = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t K = 0;
|
||||
std::size_t O = 0;
|
||||
bool TransA = false;
|
||||
bool TransB = false;
|
||||
bool TransB1 = false;
|
||||
bool TransC = false;
|
||||
DataType ADataType = DataType::Half;
|
||||
DataType BDataType = DataType::Half;
|
||||
DataType B1DataType = DataType::Half;
|
||||
DataType CDataType = DataType::Half;
|
||||
std::string AElementOp = PassThrough;
|
||||
std::string BElementOp = PassThrough;
|
||||
std::string B1ElementOp = PassThrough;
|
||||
std::string CElementOp = PassThrough;
|
||||
std::string AccElementOp = Scale;
|
||||
std::size_t M = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t K = 0;
|
||||
std::size_t O = 0;
|
||||
bool TransA = false;
|
||||
bool TransB = false;
|
||||
bool TransB1 = false;
|
||||
bool TransC = false;
|
||||
DataType ADataType = DataType::Half;
|
||||
DataType BDataType = DataType::Half;
|
||||
DataType B1DataType = DataType::Half;
|
||||
DataType CDataType = DataType::Half;
|
||||
std::string AElementOp = PassThrough;
|
||||
std::string BElementOp = PassThrough;
|
||||
std::string B1ElementOp = PassThrough;
|
||||
std::string CElementOp = PassThrough;
|
||||
std::string AccElementOp = Scale;
|
||||
bool MaskOutUpperTriangle = false;
|
||||
|
||||
// returns the correct device op file for the operation
|
||||
std::string GetIncludeHeader() const;
|
||||
|
||||
@@ -259,10 +259,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
x.tile_desc.gemm1_n_per_block);
|
||||
x.update_prologue(prologue);
|
||||
x.update_epilogue(epilogue);
|
||||
x.mask_out_upper_triangle = true;
|
||||
result.push_back(x);
|
||||
|
||||
x.mask_out_upper_triangle = false;
|
||||
x.mask_out_upper_triangle = prob.MaskOutUpperTriangle;
|
||||
result.push_back(x);
|
||||
}
|
||||
return result;
|
||||
@@ -273,13 +270,20 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
std::vector<std::vector<Operation_Xdl_CShuffle>>
|
||||
Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue)
|
||||
{
|
||||
std::vector<Problem> problems;
|
||||
|
||||
Problem prob;
|
||||
prob.TransA = false;
|
||||
prob.TransB = true;
|
||||
prob.TransB1 = false;
|
||||
prob.TransC = false;
|
||||
problems.push_back(prob);
|
||||
|
||||
return {CreateOperations(prob, prologue, epilogue)};
|
||||
prob.MaskOutUpperTriangle = true;
|
||||
problems.push_back(prob);
|
||||
|
||||
return Transform(problems,
|
||||
[&](const Problem& p) { return CreateOperations(p, prologue, epilogue); });
|
||||
}
|
||||
|
||||
static const char* const DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate =
|
||||
|
||||
@@ -42,7 +42,7 @@ TEST_CASE(test_problem_kernel)
|
||||
prob.K = 1024;
|
||||
prob.O = 1024;
|
||||
prob.TransB = true;
|
||||
check_all<half> check1, check2;
|
||||
check_all<half> check;
|
||||
auto a = to_gpu(generate_buffer<half>(1024 * 1024, 0));
|
||||
auto b = to_gpu(generate_buffer<half>(1024 * 1024, 1));
|
||||
auto b1 = to_gpu(generate_buffer<half>(1024 * 1024, 2));
|
||||
@@ -77,10 +77,8 @@ TEST_CASE(test_problem_kernel)
|
||||
k.launch(nullptr, grid_size * block_size, block_size)(
|
||||
a.data(), b.data(), b1.data(), c.data());
|
||||
|
||||
if(solution.GetTemplateParameter<bool>("MaskOutUpperTriangle"))
|
||||
CHECK(report(solution, check1(rtc::from_gpu(c))));
|
||||
else
|
||||
CHECK(report(solution, check2(rtc::from_gpu(c))));
|
||||
// NOTE: Solutions where MaskOutUpperTriangle is True don't produce consistent results
|
||||
CHECK(report(solution, check(rtc::from_gpu(c))));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -279,6 +279,7 @@ static kernel hiprtc_compile_kernel(const std::vector<src_file>& srcs, compile_o
|
||||
{
|
||||
options.flags += " -I. -O3";
|
||||
options.flags += " -std=c++17";
|
||||
options.flags += " -DCK_CODE_GEN_RTC";
|
||||
options.flags += " --offload-arch=" + get_device_name();
|
||||
auto cos = compile_hip_src_with_hiprtc(srcs, options);
|
||||
if(cos.size() != 1)
|
||||
|
||||
Reference in New Issue
Block a user