You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

348 lines
10 KiB

#
# Copyright 2025 Splunk Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""A simple thread pool implementation."""
import multiprocessing
import queue
import threading
import traceback
from time import time
import logging
class ThreadPool:
    """A simple thread pool implementation.

    Worker threads pull callables from a bounded work queue and run them.
    A separate admin thread wakes up periodically and grows or shrinks
    the pool based on queue depth and thread idleness.
    """

    # Shrink the pool when at least this fraction of threads are idle.
    _high_watermark = 0.2
    # Minimum number of seconds between two automatic resize attempts.
    _resize_window = 10

    def __init__(self, min_size=1, max_size=128, task_queue_size=1024, daemon=True):
        """
        :param min_size: initial/minimum number of worker threads; falls
            back to the CPU count when falsy or non-positive.
        :param max_size: maximum number of worker threads the admin thread
            will grow to; falls back to 8 * CPU count when falsy or
            non-positive.
        :param task_queue_size: bound of the internal work queue; must be
            truthy.
        :param daemon: whether worker and admin threads run as daemons.
            Non-daemon pools are joined during tear_down().
        """
        assert task_queue_size

        if not min_size or min_size <= 0:
            min_size = multiprocessing.cpu_count()

        if not max_size or max_size <= 0:
            max_size = multiprocessing.cpu_count() * 8

        self._min_size = min_size
        self._max_size = max_size
        self._daemon = daemon

        self._work_queue = queue.Queue(task_queue_size)
        # Worker threads are created here but only started in start().
        self._thrs = []
        for _ in range(min_size):
            thr = threading.Thread(target=self._run)
            self._thrs.append(thr)
        self._admin_queue = queue.Queue()
        self._admin_thr = threading.Thread(target=self._do_admin)
        self._last_resize_time = time()
        self._last_size = min_size
        self._lock = threading.Lock()
        # Number of threads currently executing a job, guarded by _count_lock.
        self._occupied_threads = 0
        self._count_lock = threading.Lock()
        self._started = False

    def start(self):
        """Start threads in the pool. Idempotent."""
        with self._lock:
            if self._started:
                return
            self._started = True

            for thr in self._thrs:
                thr.daemon = self._daemon
                thr.start()

            self._admin_thr.start()
            logging.info("ThreadPool started.")

    def tear_down(self):
        """Tear down thread pool. Idempotent.

        Posts one ``None`` sentinel per worker thread plus one for the
        admin thread; joins the threads only for non-daemon pools.
        """
        with self._lock:
            if not self._started:
                return
            self._started = False

            # Each worker exits on the first None it dequeues.
            for thr in self._thrs:
                self._work_queue.put(None, block=True)

            self._admin_queue.put(None)

            if not self._daemon:
                logging.info("Wait for threads to stop.")
                for thr in self._thrs:
                    thr.join()
                self._admin_thr.join()

            logging.info("ThreadPool stopped.")

    def enqueue_funcs(self, funcs, block=True):
        """run jobs in a fire and forget way, no result will be handled over to
        clients.

        :param funcs: tuple/list-like or generator like object, func shall be
            callable
        :param block: whether to block when the work queue is full.
        """
        if not self._started:
            logging.info("ThreadPool has already stopped.")
            return

        for func in funcs:
            self._work_queue.put(func, block)

    def apply_async(self, func, args=(), kwargs=None, callback=None):
        """
        :param func: callable
        :param args: free params
        :param kwargs: named params
        :callback: when func is done and without exception, call the callback
        :return AsyncResult, clients can poll or wait the result through it;
            None when the pool is not running.
        """
        if not self._started:
            logging.info("ThreadPool has already stopped.")
            return None

        res = AsyncResult(func, args, kwargs, callback)
        self._work_queue.put(res)
        return res

    def apply(self, func, args=(), kwargs=None):
        """
        :param func: callable
        :param args: free params
        :param kwargs: named params
        :return whatever the func returns, or None when the pool is not
            running.
        """
        if not self._started:
            logging.info("ThreadPool has already stopped.")
            return None

        res = self.apply_async(func, args, kwargs)
        return res.get()

    def size(self):
        """Return the pool size as of the last (re)size operation."""
        return self._last_size

    def resize(self, new_size):
        """Resize the pool size, spawn or destroy threads if necessary."""
        if new_size <= 0:
            return

        # If the lock is held, a start/tear_down/resize is in flight; skip.
        if self._lock.locked() or not self._started:
            logging.info(
                "Try to resize thread pool during the tear " "down process, do nothing"
            )
            return

        with self._lock:
            self._remove_exited_threads_with_lock()
            size = self._last_size
            self._last_size = new_size
            if new_size > size:
                for _ in range(new_size - size):
                    thr = threading.Thread(target=self._run)
                    thr.daemon = self._daemon
                    thr.start()
                    self._thrs.append(thr)
            elif new_size < size:
                # Surplus workers exit when they dequeue a None sentinel.
                for _ in range(size - new_size):
                    self._work_queue.put(None)
        logging.info("Finished ThreadPool resizing. New size=%d", new_size)

    def _remove_exited_threads_with_lock(self):
        """Join the exited threads last time when resize was called."""
        joined_thrs = set()
        for thr in self._thrs:
            if not thr.is_alive():
                try:
                    if not thr.daemon:
                        thr.join(timeout=0.5)
                    joined_thrs.add(thr.ident)
                except RuntimeError:
                    pass

        if joined_thrs:
            live_thrs = []
            for thr in self._thrs:
                if thr.ident not in joined_thrs:
                    live_thrs.append(thr)
            self._thrs = live_thrs

    def _do_resize_according_to_loads(self):
        """Grow/shrink the pool based on queue depth vs. idle threads."""
        if (
            self._last_resize_time
            and time() - self._last_resize_time < self._resize_window
        ):
            return

        thr_size = self._last_size
        free_thrs = thr_size - self._occupied_threads
        work_size = self._work_queue.qsize()

        logging.debug(
            "current_thr_size=%s, free_thrs=%s, work_size=%s",
            thr_size,
            free_thrs,
            work_size,
        )
        if work_size and work_size > free_thrs:
            # Backlog exceeds idle capacity: double, capped at max_size.
            if thr_size < self._max_size:
                thr_size = min(thr_size * 2, self._max_size)
                self.resize(thr_size)
        elif free_thrs > 0:
            free = free_thrs * 1.0
            if free / thr_size >= self._high_watermark and free_thrs >= 2:
                # 20 % thrs are idle, tear down half of the idle ones
                thr_size = thr_size - int(free_thrs // 2)
                if thr_size > self._min_size:
                    self.resize(thr_size)
        self._last_resize_time = time()

    def _do_admin(self):
        """Admin thread loop: periodically trigger load-based resizing."""
        admin_q = self._admin_queue
        resize_win = self._resize_window
        while True:
            try:
                wakeup = admin_q.get(timeout=resize_win + 1)
            except queue.Empty:
                self._do_resize_according_to_loads()
                continue

            if wakeup is None:
                break
            else:
                self._do_resize_according_to_loads()
        # NOTE: getName() is deprecated since Python 3.10; use .name.
        logging.info(
            "ThreadPool admin thread=%s stopped.", threading.current_thread().name
        )

    def _run(self):
        """Threads callback func, run forever to handle jobs from the job
        queue."""
        work_queue = self._work_queue
        count_lock = self._count_lock
        while True:
            logging.debug("Going to get job")
            func = work_queue.get()
            if func is None:
                break

            if not self._started:
                break

            logging.debug("Going to exec job")
            with count_lock:
                self._occupied_threads += 1

            try:
                func()
            except Exception:
                # Never let a job's exception kill the worker thread.
                logging.error(traceback.format_exc())

            with count_lock:
                self._occupied_threads -= 1

            logging.debug("Done with exec job")
            logging.info("Thread work_queue_size=%d", work_queue.qsize())
        logging.debug("Worker thread %s stopped.", threading.current_thread().name)
class AsyncResult:
    """Result handle for a job submitted via ``ThreadPool.apply_async``.

    A worker thread invokes the instance, which runs the wrapped function
    and deposits either its return value or the raised exception into an
    internal single-item queue for consumers to retrieve.
    """

    def __init__(self, func, args, kwargs, callback):
        """
        :param func: callable to execute.
        :param args: positional arguments for ``func``.
        :param kwargs: keyword arguments for ``func`` (may be None).
        :param callback: zero-argument callable invoked after ``func``
            completes without raising; not called on failure.
        """
        self._func = func
        self._args = args
        self._kwargs = kwargs
        self._callback = callback
        # Holds exactly one item once the call has finished: the result
        # value, or the exception instance on failure.
        self._q = queue.Queue()

    def __call__(self):
        """Run the wrapped function; never propagates its exception."""
        try:
            if self._args and self._kwargs:
                res = self._func(*self._args, **self._kwargs)
            elif self._args:
                res = self._func(*self._args)
            elif self._kwargs:
                res = self._func(**self._kwargs)
            else:
                res = self._func()
        except Exception as e:
            # Park the exception for get()/successful() to observe.
            self._q.put(e)
            return
        else:
            self._q.put(res)

        if self._callback is not None:
            self._callback()

    def get(self, timeout=None):
        """Return the result when it arrives.

        If timeout is not None and the result does not arrive within
        timeout seconds then multiprocessing.TimeoutError is raised. If
        the remote call raised an exception then that exception will be
        reraised by get().
        """
        try:
            res = self._q.get(timeout=timeout)
        except queue.Empty:
            raise multiprocessing.TimeoutError("Timed out")

        if isinstance(res, Exception):
            raise res
        return res

    def wait(self, timeout=None):
        """Wait until the result is available or until timeout seconds pass."""
        try:
            res = self._q.get(timeout=timeout)
        except queue.Empty:
            pass
        else:
            # Put the result back so later get()/ready() still see it.
            self._q.put(res)

    def ready(self):
        """Return whether the call has completed."""
        # BUGFIX: queue.Queue has no __len__, so the previous
        # ``len(self._q)`` always raised TypeError; use empty() instead.
        return not self._q.empty()

    def successful(self):
        """Return whether the call completed without raising an exception.

        Will raise AssertionError if the result is not ready.
        """
        if not self.ready():
            raise AssertionError("Function is not ready")

        res = self._q.get()
        # Put the item back so the result remains retrievable.
        self._q.put(res)
        if isinstance(res, Exception):
            return False
        return True