import msgpack
import redis
from aiohttp import web
from params_proto import Proto, ParamsProto, Flag
from zaku.base import Server
from zaku.interfaces import Job
class Redis(ParamsProto, prefix="redis", cli_parse=False):
    """Redis Configuration for the TaskServer class.

    Builds an asyncio redis client and stores it on ``self.connection``. When
    ``sentinel_hosts`` is set, the client is obtained from a Sentinel for the
    cluster named ``cluster_name``; otherwise it connects directly to
    ``host:port``.

    Setting keepalive: ConnectionError is due to timeout. See below for more information.
    https://devcenter.heroku.com/articles/ah-redis-stackhero#:~:text=The%20error%20%E2%80%9Credis.,and%20the%20connection%20closes%20automatically

    .. code-block:: shell

        # Put this into an .env file
        REDIS_HOST=localhost
        REDIS_PORT=6379
        REDIS_PASSWORD=xxxxxxxxxxxxxxxxxxx
        REDIS_DB=0

    CLI Options::

        --redis.host      :str 'localhost'
        --redis.port      :int 6379
        --redis.password  :any None
        --redis.db        :any 0
    """

    host: str = Proto("localhost", env="REDIS_HOST")
    port: int = Proto(6379, env="REDIS_PORT")
    password: str = Proto(env="REDIS_PASSWORD")

    # todo: add support for cluster mode without sentinel.
    sentinel_hosts: str = Proto(
        env="SENTINEL_HOSTS", help="comma separated list of redis hosts. Example: host1:port1,host2:port2,host3:port3."
    )
    # The default is the ``password`` field declared above.
    # NOTE(review): confirm params_proto resolves a Proto used as a default value.
    sentinel_password: str = Proto(password, env="SENTINEL_PASSWORD")
    cluster_name = Proto("primary", env="SENTINEL_CLUSTER_NAME")
    shuffle: bool = Proto(False, env="REDIS_SHUFFLE", help="shuffle the sentinel hosts on init.")
    db: int = Proto(0, env="REDIS_DB", help="The logical redis database, from 0 - 15.")

    # connection arguments
    health_check_interval = 10
    socket_connect_timeout = 5
    retry_on_timeout = True
    socket_keepalive = True

    def __post_init__(self, _deps=None):
        """Create the redis connection from the resolved configuration.

        Args:
            _deps: accepted for the ParamsProto hook signature; not read here.

        Side effects:
            Sets ``self.connection``. In sentinel mode it also replaces
            ``self.sentinel_hosts`` with a list of ``[host, port]`` pairs and
            sets ``self.sentinel``.
        """
        if self.sentinel_hosts:
            import random

            # "host1:port1,host2:port2" -> [["host1", "port1"], ["host2", "port2"]]
            self.sentinel_hosts = [h.split(":") for h in self.sentinel_hosts.split(",")]

            if self.shuffle:
                # random.shuffle works in place, returns None and takes no
                # ``key`` argument, so its result must not be assigned back.
                random.shuffle(self.sentinel_hosts)

            self.sentinel = redis.asyncio.sentinel.Sentinel(
                self.sentinel_hosts,
                # redis_class=RobustRedis,
                password=self.sentinel_password,
                health_check_interval=self.health_check_interval,
                socket_connect_timeout=self.socket_connect_timeout,
                retry_on_timeout=self.retry_on_timeout,
                socket_keepalive=self.socket_keepalive,
            )
            self.connection = self.sentinel.master_for(self.cluster_name, password=self.password, db=self.db)
        else:
            # self.connection = RobustRedis(password=self.password, db=self.db, host=self.host, port=self.port)
            self.connection = redis.asyncio.Redis(
                password=self.password,
                db=self.db,
                host=self.host,
                port=self.port,
                health_check_interval=self.health_check_interval,
                socket_connect_timeout=self.socket_connect_timeout,
                retry_on_timeout=self.retry_on_timeout,
                socket_keepalive=self.socket_keepalive,
            )
class TaskServer(ParamsProto, Server):
    """TaskServer

    This is the server that maintains the Task Queue.

    Usage

    .. code-block:: python

        app = TaskServer()
        app.run()

    CLI Options

    .. code-block:: shell

        python -m zaku.server --help

        -h, --help            show this help message and exit
        --prefix              :str 'Zaku-task-queues'
        --queue-len           :int 100
        --host                :str 'localhost'
                              set to 0.0.0.0 to enable remote (not localhost) connections.
        --port                :int 9000
        --cors                :str 'https://vuer.ai,https://dash.ml,http://lo...
        --cert                :str None the path to the SSL certificate
        --key                 :str None the path to the SSL key
        --ca-cert             :str None the trusted root CA certificates
        --REQUEST-MAX-SIZE    :int 100000000 the maximum packet size
        --free-port           :bool False
                              kill process squatting target port if True.
        --static-root         :str '.'
        --verbose             :bool False
                              show the list of configurations during launch if True.
    """

    prefix = "Zaku-task-queues"

    # Queue parameters
    queue_len = 100  # capped so that memory use does not blow up.

    # Server parameters
    host: str = Proto(
        "localhost",
        help="set to 0.0.0.0 to enable remote (not localhost) connections.",
    )
    port: int = 9000
    cors: str = "https://vuer.ai,https://dash.ml,http://localhost:8000,http://127.0.0.1:8000,*"

    # SSL parameters
    cert: str = Proto(None, dtype=str, help="the path to the SSL certificate")
    key: str = Proto(None, dtype=str, help="the path to the SSL key")
    ca_cert: str = Proto(None, dtype=str, help="the trusted root CA certificates")

    REQUEST_MAX_SIZE: int = Proto(100_000_000, env="WEBSOCKET_MAX_SIZE", help="the maximum packet size")

    free_port: bool = Flag("kill process squatting target port if True.")
    static_root: str = "."
    verbose = Flag("show the list of configurations during launch if True.")

    def print_info(self):
        """Print each instance attribute as ``name = value`` between two banner lines."""
        print("========= Arguments =========")
        for name, value in vars(self).items():
            print(f" {name} = {value},")
        print("-----------------------------")

    def __post_init__(self, _deps=None):
        """Initialise the base server and attach the redis connection.

        Prints the configuration first when ``verbose`` is set, then runs
        ``Server.__post_init__`` and builds a ``Redis`` wrapper from ``_deps``.
        """
        if self.verbose:
            self.print_info()

        Server.__post_init__(self)

        wrapper = Redis(_deps)
        self.redis_wrapper = wrapper
        self.redis = wrapper.connection

    async def create_queue(self, request: web.Request):
        """Create a queue from the JSON body; failures are reported in the response text."""
        params = await request.json()
        try:
            await Job.create_queue(self.redis, **params, prefix=self.prefix)
        except Exception as err:
            # the error handling here is not ideal: failures still answer with 200.
            return web.Response(text=f"ERROR: {err!s}", status=200)
        else:
            return web.Response(text="OK")

    async def add_job(self, request: web.Request):
        """Decode the msgpack body and pass its fields to ``Job.add``."""
        job_kwargs = msgpack.unpackb(await request.read())
        await Job.add(self.redis, prefix=self.prefix, **job_kwargs)
        return web.Response(text="OK")

    async def publish_job(self, request: web.Request):
        """Decode the msgpack body, publish it, and reply with ``Job.publish``'s result as text."""
        message_kwargs = msgpack.unpackb(await request.read())
        num_subscribers = await Job.publish(self.redis, prefix=self.prefix, **message_kwargs)
        # todo: return the number of subscribers.
        return web.Response(text=str(num_subscribers), status=200)

    async def reset_handler(self, request: web.Request):
        """Pass the JSON body's fields to ``Job.reset``."""
        params = await request.json()
        await Job.reset(self.redis, **params, prefix=self.prefix)
        return web.Response(text="OK")

    async def remove_handler(self, request: web.Request):
        """Pass the JSON body's fields to ``Job.remove``."""
        params = await request.json()
        await Job.remove(self.redis, **params, prefix=self.prefix)
        return web.Response(text="OK")

    async def count_files_handler(self, request):
        """Reply with msgpack-encoded ``{"counts": ...}`` from ``Job.count_files``.

        A redis ``ResponseError`` mentioning "no such index" yields an empty
        200 response; any other ``ResponseError`` is re-raised.
        """
        params = await request.json()
        try:
            counts = await Job.count_files(self.redis, **params, prefix=self.prefix)
        except redis.exceptions.ResponseError as err:
            if "no such index" not in str(err):
                raise err
            # return web.Response(text="no such index", status=404)
            return web.Response(status=200)

        packed = msgpack.packb({"counts": counts}, use_bin_type=True)
        return web.Response(body=packed, status=200)

    async def take_handler(self, request):
        """Take a job and reply with msgpack-encoded ``{"job_id", "payload"}``.

        Replies with an empty 200 response when the payload is falsy or when
        redis reports "no such index"; other ``ResponseError``s are re-raised.
        """
        params = await request.json()
        try:
            job_id, payload = await Job.take(self.redis, **params, prefix=self.prefix)
        except redis.exceptions.ResponseError as err:
            if "no such index" not in str(err):
                raise err
            # return web.Response(text="no such index", status=404)
            return web.Response(status=200)

        if not payload:
            return web.Response(status=200)

        packed = msgpack.packb({"job_id": job_id, "payload": payload}, use_bin_type=True)
        return web.Response(body=packed, status=200)

    async def unstale_handler(self, request: web.Request):
        """Pass the JSON body's fields to ``Job.unstale_tasks``."""
        params = await request.json()
        await Job.unstale_tasks(self.redis, **params, prefix=self.prefix)
        return web.Response(text="OK", status=200)

    async def subscribe_one_handler(self, request):
        """Reply with the payload from ``Job.subscribe``, or an empty 200 if it is falsy."""
        params = await request.json()
        payload = await Job.subscribe(self.redis, **params, prefix=self.prefix)
        if not payload:
            return web.Response(status=200)
        return web.Response(body=payload, status=200)

    async def subscribe_streaming_handler(self, request) -> web.StreamResponse:
        """Stream every payload yielded by ``Job.subscribe_stream`` to the client.

        The response is prepared before streaming starts. A
        ``ConnectionResetError`` while writing is logged and ends the stream.
        """
        params = await request.json()

        response = web.StreamResponse(
            status=200,
            reason="OK",
            headers={"Content-Type": "text/plain"},
        )
        await response.prepare(request)

        try:
            async for chunk in Job.subscribe_stream(self.redis, **params, prefix=self.prefix):
                # use msgpack.Unpacker to determine the end of message.
                await response.write(chunk)
            await response.write_eof()
        except ConnectionResetError:
            print("client disconnected.")

        return response

    def setup_server(self):
        """Register all HTTP routes and the ``/static`` file endpoint, then return ``self.app``."""
        import os

        # use the same endpoint for websocket and file serving.
        route_table = (
            ("/queues", self.create_queue, "PUT"),
            ("/tasks", self.add_job, "PUT"),
            ("/tasks", self.take_handler, "POST"),
            ("/tasks/counts", self.count_files_handler, "GET"),
            ("/tasks/reset", self.reset_handler, "POST"),
            ("/tasks", self.remove_handler, "DELETE"),
            ("/tasks/unstale", self.unstale_handler, "PUT"),
            ("/publish", self.publish_job, "PUT"),
            ("/subscribe_one", self.subscribe_one_handler, "POST"),
            ("/subscribe_stream", self.subscribe_streaming_handler, "POST"),
        )
        for path, handler, http_method in route_table:
            self._route(path, handler, method=http_method)

        # serve local files via /static endpoint
        self._static("/static", self.static_root)
        print("Serving file://" + os.path.abspath(self.static_root), "at", "/static")

        return self.app

    def run(self, kill=None, *args, **kwargs):
        """Optionally free the target port, set up the routes, and start serving.

        Args:
            kill: when truthy (or when ``free_port`` is set), ``kill_ports`` is
                called for ``self.port`` before start-up.
            *args, **kwargs: accepted but not used.
        """
        should_free_port = kill or self.free_port
        if should_free_port:
            import time

            from killport import kill_ports

            kill_ports(ports=[self.port])
            time.sleep(0.01)

        self.setup_server()

        print("redis server is now connected")
        print(f"serving at http://localhost:{self.port}")

        super().run()
def entry_point():
    """Construct a TaskServer with its default arguments and run it."""
    server = TaskServer()
    server.run()
async def get_app():
    """This is used by gunicorn to run the server in worker mode.

    Returns whatever ``TaskServer.setup_server`` returns for a freshly
    constructed server.
    """
    server = TaskServer()
    return server.setup_server()
# Start the task server when this module is executed as a script.
if __name__ == "__main__":
    entry_point()