mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
debug mixed_prec flatmm
This commit is contained in:
@@ -28,10 +28,60 @@ struct FlatmmProblem
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
template <int SharedGranularity>
|
||||
template <int SharedGranularityMN, int SharedGranularityK = 0>
|
||||
struct FlatmmScalePointer
|
||||
{
|
||||
static constexpr int granularity = SharedGranularity;
|
||||
static constexpr int GranularityMN = SharedGranularityMN;
|
||||
static constexpr int GranularityK = SharedGranularityK;
|
||||
|
||||
const float* ptr;
|
||||
index_t scale_stride = 1;
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t stride)
|
||||
: ptr(ptr_), scale_stride(stride)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
|
||||
{
|
||||
FlatmmScalePointer ret;
|
||||
// if constexpr(GranularityMN == 0)
|
||||
// {
|
||||
// ret.scalar = scalar;
|
||||
// }
|
||||
// else if constexpr(GranularityMN == 1)
|
||||
// {
|
||||
// ret.ptr = ptr + offset;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// ret.ptr = ptr + offset / GranularityMN;
|
||||
// }
|
||||
return ret;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float operator[](index_t i) const
|
||||
{
|
||||
if constexpr(GranularityMN == 1)
|
||||
{
|
||||
return ptr[i];
|
||||
}
|
||||
else
|
||||
{
|
||||
return ptr[i / GranularityMN];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int SharedGranularityMN>
|
||||
struct FlatmmScalePointer<SharedGranularityMN, 0>
|
||||
{
|
||||
static constexpr int GranularityMN = SharedGranularityMN;
|
||||
static constexpr int GranularityK = 0;
|
||||
|
||||
static_assert(GranularityMN != 0);
|
||||
|
||||
union
|
||||
{
|
||||
@@ -42,50 +92,63 @@ struct FlatmmScalePointer
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(float scalar_) : scalar(scalar_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t stride)
|
||||
: ptr(ptr_)
|
||||
{
|
||||
}
|
||||
|
||||
// Produces a handle shifted forward by `offset` indices.
// GranularityMN == 0 : the scalar value is copied, there is nothing to advance.
// otherwise          : the pointer moves by offset / GranularityMN entries
//                      (which is exactly `offset` when GranularityMN == 1).
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
{
    FlatmmScalePointer shifted;
    if constexpr(GranularityMN == 0)
    {
        shifted.scalar = scalar;
    }
    else
    {
        shifted.ptr = ptr + offset / GranularityMN;
    }
    return shifted;
}
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer& advance() { return *this; }
|
||||
|
||||
// Looks up the scale that applies to index `i`.
// GranularityMN == 0 : the stored scalar is returned for every index.
// otherwise          : entry i / GranularityMN of the pointed-to array
//                      (entry i itself when GranularityMN == 1).
CK_TILE_HOST_DEVICE float operator[](index_t i) const
{
    if constexpr(GranularityMN == 0)
    {
        return scalar;
    }
    else
    {
        return ptr[i / GranularityMN];
    }
}
|
||||
};
|
||||
// shared granularity = -1 means no scale
|
||||
|
||||
// SharedGranularityMN == -1 means no scale
|
||||
template <>
|
||||
struct FlatmmScalePointer<-1>
|
||||
struct FlatmmScalePointer<-1, 0>
|
||||
{
|
||||
static constexpr int granularity = -1;
|
||||
static constexpr int GranularityMN = -1;
|
||||
static constexpr int GranularityK = 0;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(float scalar_) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float* ptr_) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(float) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, [[maybe_unused]] index_t stride)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE FlatmmScalePointer& advance() { return *this; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
|
||||
{
|
||||
@@ -150,7 +213,6 @@ struct BaseFlatmmHostArgs
|
||||
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>,
|
||||
class ScaleN = FlatmmScalePointer<-1>,
|
||||
index_t NumDTensor = 0>
|
||||
@@ -558,9 +620,9 @@ struct FlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK = FlatmmPipeline::flatKPerWarp * (kargs.K /
|
||||
BlockGemmShape::WarpTile::at(I2));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
index_t kFlatK =
|
||||
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
@@ -776,7 +838,7 @@ struct FlatmmKernel
|
||||
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(ScaleM::granularity != -1 || ScaleN::granularity != -1)
|
||||
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}.template
|
||||
|
||||
Reference in New Issue
Block a user