You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
348 lines
10 KiB
348 lines
10 KiB
#
|
|
# Copyright 2025 Splunk Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
"""A simple thread pool implementation."""
|
|
|
|
import multiprocessing
|
|
import queue
|
|
import threading
|
|
import traceback
|
|
from time import time
|
|
import logging
|
|
|
|
|
|
class ThreadPool:
    """A simple thread pool implementation.

    Worker threads pull callables off a bounded work queue and execute
    them.  A separate admin thread periodically grows or shrinks the
    pool based on load (queue backlog vs. idle workers).
    """

    # Shrink when at least this fraction of the threads is idle.
    _high_watermark = 0.2
    # Minimum number of seconds between two automatic resize attempts.
    _resize_window = 10

    def __init__(self, min_size=1, max_size=128, task_queue_size=1024, daemon=True):
        """
        :param min_size: minimum number of worker threads; falls back to
            the CPU count when falsy or non-positive.
        :param max_size: maximum number of worker threads; falls back to
            8x the CPU count when falsy or non-positive.
        :param task_queue_size: bound of the internal work queue; must be
            truthy.
        :param daemon: whether worker/admin threads run as daemon threads.
        """
        assert task_queue_size

        if not min_size or min_size <= 0:
            min_size = multiprocessing.cpu_count()

        if not max_size or max_size <= 0:
            max_size = multiprocessing.cpu_count() * 8

        self._min_size = min_size
        self._max_size = max_size
        self._daemon = daemon

        self._work_queue = queue.Queue(task_queue_size)
        self._thrs = []
        for _ in range(min_size):
            thr = threading.Thread(target=self._run)
            self._thrs.append(thr)
        self._admin_queue = queue.Queue()
        self._admin_thr = threading.Thread(target=self._do_admin)
        self._last_resize_time = time()
        self._last_size = min_size
        self._lock = threading.Lock()
        self._occupied_threads = 0
        self._count_lock = threading.Lock()
        self._started = False

    def start(self):
        """Start threads in the pool.  Idempotent."""

        with self._lock:
            if self._started:
                return
            self._started = True

            for thr in self._thrs:
                thr.daemon = self._daemon
                thr.start()

            self._admin_thr.start()
            logging.info("ThreadPool started.")

    def tear_down(self):
        """Tear down thread pool.  Idempotent."""

        with self._lock:
            if not self._started:
                return
            self._started = False

            # One None sentinel per worker so that every worker thread
            # exits its run loop.
            for thr in self._thrs:
                self._work_queue.put(None, block=True)

            # Wake the admin thread so it can exit too.
            self._admin_queue.put(None)

            if not self._daemon:
                logging.info("Wait for threads to stop.")
                for thr in self._thrs:
                    thr.join()
                self._admin_thr.join()

            logging.info("ThreadPool stopped.")

    def enqueue_funcs(self, funcs, block=True):
        """Run jobs in a fire-and-forget way; no result will be handed over
        to clients.

        :param funcs: tuple/list-like or generator-like object; each func
            shall be callable.
        :param block: whether to block when the work queue is full.
        """

        if not self._started:
            logging.info("ThreadPool has already stopped.")
            return

        for func in funcs:
            self._work_queue.put(func, block)

    def apply_async(self, func, args=(), kwargs=None, callback=None):
        """
        :param func: callable
        :param args: free params
        :param kwargs: named params
        :param callback: when func is done and without exception, call the
            callback
        :return: AsyncResult, clients can poll or wait the result through
            it; None when the pool has already stopped.
        """

        if not self._started:
            logging.info("ThreadPool has already stopped.")
            return None

        res = AsyncResult(func, args, kwargs, callback)
        self._work_queue.put(res)
        return res

    def apply(self, func, args=(), kwargs=None):
        """
        :param func: callable
        :param args: free params
        :param kwargs: named params
        :return: whatever the func returns; None when the pool has already
            stopped.
        """

        if not self._started:
            logging.info("ThreadPool has already stopped.")
            return None

        res = self.apply_async(func, args, kwargs)
        return res.get()

    def size(self):
        """Return the pool size recorded at the last resize."""
        return self._last_size

    def resize(self, new_size):
        """Resize the pool size, spawn or destroy threads if necessary.

        :param new_size: desired worker count; non-positive values are
            ignored.
        """

        if new_size <= 0:
            return

        # If the lock is held, a start/tear_down is in flight; skip.
        if self._lock.locked() or not self._started:
            logging.info(
                "Try to resize thread pool during the tear down process, do nothing"
            )
            return

        with self._lock:
            self._remove_exited_threads_with_lock()
            size = self._last_size
            self._last_size = new_size
            if new_size > size:
                for _ in range(new_size - size):
                    thr = threading.Thread(target=self._run)
                    thr.daemon = self._daemon
                    thr.start()
                    self._thrs.append(thr)
            elif new_size < size:
                # Surplus workers exit when they pick up a None sentinel.
                for _ in range(size - new_size):
                    self._work_queue.put(None)
        logging.info("Finished ThreadPool resizing. New size=%d", new_size)

    def _remove_exited_threads_with_lock(self):
        """Join the exited threads last time when resize was called."""

        joined_thrs = set()
        for thr in self._thrs:
            if not thr.is_alive():
                try:
                    if not thr.daemon:
                        thr.join(timeout=0.5)
                    joined_thrs.add(thr.ident)
                except RuntimeError:
                    pass

        if joined_thrs:
            live_thrs = []
            for thr in self._thrs:
                if thr.ident not in joined_thrs:
                    live_thrs.append(thr)
            self._thrs = live_thrs

    def _do_resize_according_to_loads(self):
        # Throttle: at most one automatic resize per _resize_window seconds.
        if (
            self._last_resize_time
            and time() - self._last_resize_time < self._resize_window
        ):
            return

        thr_size = self._last_size
        free_thrs = thr_size - self._occupied_threads
        work_size = self._work_queue.qsize()

        logging.debug(
            "current_thr_size=%s, free_thrs=%s, work_size=%s",
            thr_size,
            free_thrs,
            work_size,
        )
        if work_size and work_size > free_thrs:
            # Backlog exceeds idle capacity: double the pool (capped).
            if thr_size < self._max_size:
                thr_size = min(thr_size * 2, self._max_size)
                self.resize(thr_size)
        elif free_thrs > 0:
            free = free_thrs * 1.0
            if free / thr_size >= self._high_watermark and free_thrs >= 2:
                # 20% thrs are idle, tear down half of the idle ones
                thr_size = thr_size - free_thrs // 2
                if thr_size > self._min_size:
                    self.resize(thr_size)
        self._last_resize_time = time()

    def _do_admin(self):
        """Admin loop: periodically auto-resize until told to stop."""
        admin_q = self._admin_queue
        resize_win = self._resize_window
        while True:
            try:
                wakeup = admin_q.get(timeout=resize_win + 1)
            except queue.Empty:
                # Timed out waiting: a normal periodic resize check.
                self._do_resize_according_to_loads()
                continue

            if wakeup is None:
                # None sentinel from tear_down(): exit the admin loop.
                break
            else:
                self._do_resize_according_to_loads()
        logging.info(
            "ThreadPool admin thread=%s stopped.", threading.current_thread().name
        )

    def _run(self):
        """Threads callback func, run forever to handle jobs from the job
        queue."""

        work_queue = self._work_queue
        count_lock = self._count_lock
        while True:
            logging.debug("Going to get job")
            func = work_queue.get()
            if func is None:
                # None sentinel: this worker should exit.
                break

            if not self._started:
                break

            logging.debug("Going to exec job")
            with count_lock:
                self._occupied_threads += 1

            try:
                func()
            except Exception:
                # Jobs must not kill the worker; log and keep serving.
                logging.error(traceback.format_exc())

            with count_lock:
                self._occupied_threads -= 1

            logging.debug("Done with exec job")
            logging.info("Thread work_queue_size=%d", work_queue.qsize())

        logging.debug("Worker thread %s stopped.", threading.current_thread().name)
|
|
class AsyncResult:
    """Future-like holder for the result of a job submitted to a pool.

    A worker thread invokes the instance (``__call__``), which runs the
    wrapped callable and deposits either its return value or the raised
    exception into an internal one-item queue that the client polls.
    """

    def __init__(self, func, args, kwargs, callback):
        """
        :param func: callable to execute.
        :param args: positional arguments; may be empty.
        :param kwargs: keyword arguments; may be None.
        :param callback: zero-arg callable invoked only when func
            completes without raising.
        """
        self._func = func
        self._args = args
        self._kwargs = kwargs
        self._callback = callback
        self._q = queue.Queue()

    def __call__(self):
        """Execute the wrapped callable and store its result or exception."""
        try:
            # The ()/{} fallbacks reproduce the original four-way dispatch
            # over present/absent args and kwargs in a single call.
            res = self._func(*(self._args or ()), **(self._kwargs or {}))
        except Exception as e:
            self._q.put(e)
            return
        else:
            self._q.put(res)

        if self._callback is not None:
            self._callback()

    def get(self, timeout=None):
        """Return the result when it arrives.

        If timeout is not None and the result does not arrive within
        timeout seconds then multiprocessing.TimeoutError is raised. If
        the remote call raised an exception then that exception will be
        reraised by get().
        """

        try:
            res = self._q.get(timeout=timeout)
        except queue.Empty:
            raise multiprocessing.TimeoutError("Timed out")

        if isinstance(res, Exception):
            raise res
        return res

    def wait(self, timeout=None):
        """Wait until the result is available or until timeout seconds pass."""

        try:
            res = self._q.get(timeout=timeout)
        except queue.Empty:
            pass
        else:
            # Put the result back so subsequent get()/ready() still see it.
            self._q.put(res)

    def ready(self):
        """Return whether the call has completed."""

        # Bug fix: queue.Queue has no __len__, so the original
        # ``len(self._q)`` raised TypeError on every call.
        return not self._q.empty()

    def successful(self):
        """Return whether the call completed without raising an exception.

        Will raise AssertionError if the result is not ready.
        """

        if not self.ready():
            raise AssertionError("Function is not ready")
        # Peek: take the result out and put it straight back.
        res = self._q.get()
        self._q.put(res)

        if isinstance(res, Exception):
            return False
        return True