mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
Gemm+layernorm instance, ckProfiler, client example (#568)
* Add gemm + layernorm instance
* Add ckProfiler
* Add test
* Add client example
* Detect if user forger to set the workrspace
* Use literal in the example
* [What] use builtin function for sqrt
[Why] compiler will not use v_sqrt_f64_e64 if we use ::sqrt()
* check gemm vaildity in IsSupportedArgument
* Add more testcases
* Merge duplicated folder in client example
* Print more infomation
* Use better kernel parameter for MS problem size
* clang format
* Add constexpr for if condition and remove redundant include
* Remove cstdlib and add constexpr
[ROCm/composable_kernel commit: f7d28f3e4b]
This commit is contained in:
@@ -669,6 +669,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
|
||||
}
|
||||
if(arg.p_workspace_e_grid_ == nullptr || arg.p_workspace_mean_ == nullptr ||
|
||||
arg.p_workspace_var_ == nullptr || arg.p_workspace_count_ == nullptr)
|
||||
throw std::runtime_error("wrong! WorkSpace pointer has not been set");
|
||||
|
||||
index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.gemm_e_grid_desc_m_n_);
|
||||
|
||||
@@ -939,7 +942,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
return GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.gemm_e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -1055,7 +1062,12 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
|
||||
<< GemmKPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< getGemmSpecializationString(GemmSpec)
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< PostShuffleThreadClusterSize_M_N::At(I0) << ", "
|
||||
<< PostShuffleThreadClusterSize_M_N::At(I1) << ", "
|
||||
<< LayernormThreadClusterSize_M_N::At(I0) << ", "
|
||||
<< LayernormThreadClusterSize_M_N::At(I1) << ", "
|
||||
<< LayernormThreadSliceSize_M
|
||||
<< ">"
|
||||
<< " LoopScheduler: "
|
||||
<< LoopSchedToString[LoopSched] << ", "
|
||||
|
||||
@@ -158,9 +158,9 @@ static inline __device__ bool isnan(half_t x)
|
||||
return (xx & 0x7FFF) > 0x7C00;
|
||||
};
|
||||
|
||||
static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
|
||||
static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
|
||||
|
||||
static inline __device__ double sqrt(double x) { return ::sqrt(x); };
|
||||
static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
|
||||
|
||||
} // namespace math
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user