Files
sd-webui-old-photo-restoration/Global/detection_models/sync_batchnorm/comm.py
Haoming 89a8626838 Squashed commit of the following:
commit cd7a9c103d1ea981ecd236d4e9111fd3c1cd6c2b
Author: Haoming <hmstudy02@gmail.com>
Date:   Tue Dec 19 11:33:44 2023 +0800

    add README

commit 30127cbb2a8e5f461c540729dc7ad457f66eb94c
Author: Haoming <hmstudy02@gmail.com>
Date:   Tue Dec 19 11:12:16 2023 +0800

    fix Face Enhancement distortion

commit 6d52de5368c6cfbd9342465b5238725c186e00b9
Author: Haoming <hmstudy02@gmail.com>
Date:   Mon Dec 18 18:27:25 2023 +0800

    better? args handling

commit 0d1938b59eb77a038ee0a91a66b07fb9d7b3d6d4
Author: Haoming <hmstudy02@gmail.com>
Date:   Mon Dec 18 17:40:19 2023 +0800

    bug fix related to Scratch

commit 8315cd05ffeb2d651b4c57d70bf04b413ca8901d
Author: Haoming <hmstudy02@gmail.com>
Date:   Mon Dec 18 17:24:52 2023 +0800

    implement step 2 ~ 4

commit a5feb04b3980bdd80c6b012a94c743ba48cdfe39
Author: Haoming <hmstudy02@gmail.com>
Date:   Mon Dec 18 11:55:20 2023 +0800

    process scratch

commit 3b18f7b042
Author: Haoming <hmstudy02@gmail.com>
Date:   Wed Dec 13 11:57:20 2023 +0800

    "init"

commit d0148e0e82
Author: Haoming <hmstudy02@gmail.com>
Date:   Wed Dec 13 10:34:39 2023 +0800

    clone repo
2023-12-19 11:35:38 +08:00

138 lines
4.3 KiB
Python

# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import queue
import collections
import threading
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe.

    One thread calls `put` to deliver a value; another calls `get`, which
    blocks until a value is available and then consumes it. Any value,
    including ``None``, may be passed through.
    """

    # Private marker meaning "no value pending". Using a dedicated object
    # instead of ``None`` lets ``None`` itself be delivered as a result
    # (previously a `get` issued after `put(None)` would block forever).
    _EMPTY = object()

    def __init__(self):
        self._result = self._EMPTY
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        """Store `result` and wake a thread blocked in `get`.

        Raises:
            AssertionError: if the previous result has not been fetched yet
                (only checked when assertions are enabled).
        """
        with self._lock:
            assert self._result is self._EMPTY, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        """Block until a result is available, then consume and return it."""
        with self._lock:
            # Re-check the condition after every wakeup, as recommended for
            # threading.Condition; a single `if` could fall through on a
            # wakeup that did not come from `put`.
            while self._result is self._EMPTY:
                self._cond.wait()
            res = self._result
            self._result = self._EMPTY
            return res
# Per-slave bookkeeping held by SyncMaster: the FutureResult through which the
# master hands that slave its reply. The typename matches the variable name
# (as with _SlavePipeBase below) so instances can be pickled by reference.
_MasterRegistry = collections.namedtuple('_MasterRegistry', ['result'])
# Fields of a SlavePipe: the slave's identifier, the queue shared with the
# master, and the slave's FutureResult.
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
class SlavePipe(_SlavePipeBase):
    """Slave-side endpoint of the master/slave exchange.

    Tuple fields: ``identifier`` of this slave, the ``queue`` shared with the
    master, and the ``result`` future on which the master delivers this
    slave's reply.
    """

    def run_slave(self, msg):
        """Send `msg` to the master, block for the reply, then acknowledge it.

        Returns: the reply the master delivered for this slave.
        """
        ident, channel, mailbox = self
        channel.put((ident, msg))
        reply = mailbox.get()
        # Acknowledge receipt; the master waits for one `True` per slave.
        channel.put(True)
        return reply
class SyncMaster(object):
    """Master side of a gather/scatter exchange between devices.

    - During replication, each slave device calls `register_slave(identifier)`
      and receives a `SlavePipe` for communicating with the master.
    - During the forward pass, the master device calls `run_master`: it
      collects one message from every registered slave, passes all messages
      (its own first) to the master callback, and sends each slave the reply
      the callback produced for it.
    - The first `register_slave` call after a `run_master` discards the old
      registry and starts a new registration round.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callable invoked by `run_master` with the list of
                ``(identifier, message)`` pairs collected from all devices (the
                master's own pair, with identifier 0, comes first). It must
                return a sequence of ``(identifier, reply)`` pairs whose first
                entry belongs to the master.
        """
        self._master_callback = master_callback
        # Shared by every SlavePipe: carries slave messages and acknowledgements.
        self._queue = queue.Queue()
        # identifier -> _MasterRegistry holding that slave's FutureResult.
        self._registry = collections.OrderedDict()
        # Set by run_master; the next register_slave then starts a fresh round.
        self._activated = False

    def __getstate__(self):
        # Only the callback is pickled; queue and registry are rebuilt fresh.
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id. Replies are routed
                back by this key; identifier 0 is used for the master's own
                message in `run_master`, so a slave registered as 0 would never
                receive a reply.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.

        Messages are first collected from every device (including the master
        itself), then the master callback computes the replies to send back
        to each device (including the master).

        Args:
            master_msg: the message the master sends to itself. It is placed
                first, with identifier 0, in the list passed to
                `master_callback`. See `_SynchronizedBatchNorm` for an example.

        Returns: the reply the callback produced for the master device.
        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for _ in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        # Wait for every slave to acknowledge its reply. The get() must stay
        # outside the assert: `python -O` strips asserts, which would skip
        # draining the acknowledgements and leave stale `True` values in the
        # queue to be misread as slave messages in the next round.
        for _ in range(self.nr_slaves):
            ack = self._queue.get()
            assert ack is True

        return results[0][1]

    @property
    def nr_slaves(self):
        """Number of slave devices registered in the current round."""
        return len(self._registry)