mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Introduce gemm_softmax_gemm to codegen.
This commit is contained in:
@@ -7,8 +7,10 @@
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "ck/tensor_description/tensor_adaptor.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <limits>
|
||||
#include <stdlib.h>
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -979,7 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit
|
||||
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
return std::make_tuple(N0, M0, k_split);
|
||||
return ck::make_tuple(N0, M0, k_split);
|
||||
}
|
||||
|
||||
template <typename TopIdx>
|
||||
@@ -1103,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK
|
||||
uint32_t dp_for_sk_iters = k_iters_per_tile.get();
|
||||
|
||||
uint32_t best_sk_score =
|
||||
std::numeric_limits<int>::max(); // we need to find the smallest sk iters
|
||||
ck::NumericLimits<int>::Max(); // we need to find the smallest sk iters
|
||||
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
|
||||
tentative_sk_blocks++)
|
||||
{
|
||||
|
||||
@@ -475,9 +475,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
|
||||
template <typename DsLayout, GemmSpecialization GemmSpec>
|
||||
__host__ __device__ static auto
|
||||
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
|
||||
const std::array<index_t, NumDTensor>& NRaws,
|
||||
const std::array<index_t, NumDTensor>& DsStride)
|
||||
MakeDsGridDescriptor_M_N(const Array<index_t, NumDTensor>& MRaws,
|
||||
const Array<index_t, NumDTensor>& NRaws,
|
||||
const Array<index_t, NumDTensor>& DsStride)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -941,7 +941,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
const index_t K,
|
||||
const index_t StrideA,
|
||||
const index_t StrideB,
|
||||
const std::array<index_t, NumDTensor> StrideDs,
|
||||
const Array<index_t, NumDTensor> StrideDs,
|
||||
const index_t StrideE,
|
||||
const Block2ETileMap& block_2_etile_map)
|
||||
{
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#endif
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
|
||||
@@ -53,12 +55,15 @@ constexpr auto GridwiseGemmPipeline_Selector()
|
||||
}
|
||||
else
|
||||
{
|
||||
#ifndef __HIPCC_RTC__
|
||||
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
|
||||
{
|
||||
switch(p)
|
||||
@@ -71,3 +76,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
|
||||
}
|
||||
return os;
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user