from io import BytesIO
from time import time
from types import SimpleNamespace
from typing import Literal, Any, Tuple, Coroutine, Dict, Union
import msgpack
import numpy as np
import redis
from redis import ResponseError
from redis.commands.search.query import Query
from redis.commands.search.result import Result
# Type alias for the ``ztype`` tag that ZData.get_ztype reads from a z-payload.
# NOTE(review): ZData.encode also emits the tag "image", which is not listed
# here, and nothing visible in this file emits "generic" — confirm the set.
ZType = Literal["numpy.ndarray", "torch.Tensor", "generic"]
class ZData:
    """Codec between rich Python objects and plain "z-format" dicts.

    A z-payload is a ``dict`` carrying a ``ztype`` tag plus the raw bytes under
    ``b`` (and ``dtype`` / ``shape`` for array-like data). Anything that is not
    a PIL image, an exact ``numpy.ndarray`` or an exact ``torch.Tensor`` is
    passed through untouched in both directions.
    """

    @staticmethod
    def encode(data: Union["torch.Tensor", np.ndarray]):
        """Convert images, arrays and tensors to z-format.

        Args:
            data: the value to encode.

        Returns:
            A z-payload ``dict`` for PIL images, ``numpy.ndarray`` and
            ``torch.Tensor`` inputs; ``data`` itself for everything else.
        """
        import sys

        # An instance of a PIL image or a torch tensor can only exist if the
        # defining package has already been imported, so looking the modules
        # up in sys.modules is sufficient. This keeps both packages optional
        # (and avoids a slow torch import) when encoding plain data.
        pil_image = sys.modules.get("PIL.Image")
        torch = sys.modules.get("torch")
        T = type(data)

        if pil_image is not None and isinstance(data, pil_image.Image):
            with BytesIO() as buffer:
                # use the format of the Image object, default to PNG.
                data.save(buffer, format=data.format or "PNG")
                binary = buffer.getvalue()
            return dict(ztype="image", b=binary)

        if T is np.ndarray:
            # Exact type match only: ndarray subclasses (e.g. masked arrays)
            # are not supported yet and fall through to the pass-through case.
            return dict(
                ztype="numpy.ndarray",
                b=data.tobytes(),
                dtype=str(data.dtype),
                shape=data.shape,
            )

        if torch is not None and T is torch.Tensor:
            # Always move to CPU; detach first because numpy() refuses
            # tensors that require grad.
            np_v = data.detach().cpu().numpy()
            return dict(
                ztype="torch.Tensor",
                b=np_v.tobytes(),
                dtype=str(np_v.dtype),
                shape=np_v.shape,
            )

        return data

    @staticmethod
    def get_ztype(data: Dict) -> "Union[ZType, None]":
        """Return the ``ztype`` tag if ``data`` is a z-payload, else ``None``."""
        if isinstance(data, dict) and "ztype" in data:
            return data["ztype"]
        return None

    @staticmethod
    def decode(zdata):
        """Convert a z-payload back into the object it was encoded from.

        Args:
            zdata: a z-payload ``dict``, or any other value.

        Returns:
            The decoded image / array / tensor, or ``zdata`` unchanged when it
            carries no ``ztype`` tag. Decoded numpy arrays are views over the
            payload bytes (read-only when those are ``bytes``); decoded
            tensors own a writable copy and keep the encoded dtype.

        Raises:
            TypeError: if the ``ztype`` tag is not a supported type.
        """
        T = ZData.get_ztype(zdata)
        if not T:
            return zdata

        if T == "image":
            # Imported lazily so that PIL is only required for image payloads.
            from PIL import Image

            return Image.open(BytesIO(zdata["b"]))

        if T == "numpy.ndarray":
            array = np.frombuffer(zdata["b"], dtype=zdata["dtype"])
            return array.reshape(tuple(zdata["shape"]))

        if T == "torch.Tensor":
            # Imported lazily so that torch is only required for tensor payloads.
            import torch

            array = np.frombuffer(zdata["b"], dtype=zdata["dtype"])
            # we copy the array because the buffered version is non-writable.
            array = array.reshape(tuple(zdata["shape"])).copy()
            # from_numpy keeps the dtype; torch.Tensor(array) would always
            # produce float32 and silently corrupt integer / float64 tensors.
            return torch.from_numpy(array)

        raise TypeError(f"ZData type {T} is not supported")
class Payload(SimpleNamespace):
    """A namespace of key/value pairs that can be (de)serialized with msgpack.

    When ``greedy`` is truthy, every value is passed through ``ZData.encode``
    before packing and through ``ZData.decode`` after unpacking.
    """

    # class attributes are not serialized.
    greedy = True
    """Set to False to avoid greedy conversion, and make it go faster"""

    def __init__(self, _greedy=None, **payload):
        """Create a payload from keyword arguments.

        Args:
            _greedy: optional per-instance override of the class-level
                ``greedy`` flag; ``None`` keeps the class default.
            **payload: the key/value pairs that make up the payload.
        """
        # Compare against None so an explicit ``_greedy=False`` is honoured.
        # The override is stored on the instance, so it lives in ``__dict__``
        # and is packed together with the payload by ``serialize``.
        if _greedy is not None:
            self.greedy = _greedy
        super().__init__(**payload)

    def serialize(self):
        """Pack the instance attributes into msgpack bytes.

        Returns:
            The msgpack-encoded payload. In greedy mode the values are
            z-encoded first and a ``_greedy`` marker is added so that
            ``deserialize`` knows to decode them.
        """
        payload = self.__dict__
        if not self.greedy:
            return msgpack.packb(payload, use_bin_type=True)

        # we serialize components key value pairs
        data = {k: ZData.encode(v) for k, v in payload.items()}
        data["_greedy"] = self.greedy
        return msgpack.packb(data, use_bin_type=True)

    @staticmethod
    def deserialize(payload) -> Dict:
        """Unpack msgpack bytes produced by ``serialize`` into a dict.

        Args:
            payload: msgpack-encoded bytes.

        Returns:
            The unpacked dict with the ``_greedy`` marker removed; values are
            z-decoded when the marker was present and truthy.
        """
        unpacked = msgpack.unpackb(payload, raw=False)
        is_greedy = unpacked.pop("_greedy", None)
        if not is_greedy:
            return unpacked
        return {k: ZData.decode(v) for k, v in unpacked.items()}
class Job(SimpleNamespace):
    """Job record plus static helpers for a Redis-backed job queue.

    Each job is a JSON document stored under ``{prefix}:{queue}:{job_id}``;
    its optional binary payload is a plain string key with the suffix
    ``.payload``. Jobs are located through a RediSearch index named
    ``{prefix}:{queue}``.
    """

    created_ts: float
    status: Literal[None, "in_progress", "created"] = "created"
    grab_ts: float = None
    ttl: float = None

    @staticmethod
    async def create_queue(r: "redis.asyncio.Redis", name, *, prefix, smart=True):
        """Create the search index that backs a queue.

        Args:
            r: async Redis client.
            name: queue name.
            prefix: key prefix shared by the index and its job documents.
            smart: when the index cannot be created (``ResponseError``), drop
                it and create it again; when False, return silently instead.
        """
        from redis.commands.search.field import TagField, NumericField
        from redis.commands.search.indexDefinition import IndexType, IndexDefinition

        index_name = f"{prefix}:{name}"
        index_prefix = f"{prefix}:{name}:"
        print("creating queue:", index_name)

        schema = (
            NumericField("$.created_ts", as_name="created_ts"),
            TagField("$.status", as_name="status"),
            NumericField("$.grab_ts", as_name="grab_ts"),
        )

        async def create_index():
            # Build a fresh definition per attempt, exactly like two separate
            # create_index calls would.
            await r.ft(index_name).create_index(
                schema,
                definition=IndexDefinition(
                    prefix=[index_prefix],
                    index_type=IndexType.JSON,
                ),
            )

        try:
            await create_index()
        except ResponseError:
            if not smart:
                return
            await r.ft(index_name).dropindex()
            await create_index()

    @staticmethod
    async def remove_queue(r: "redis.asyncio.Redis", queue, *, prefix):
        """Drop the search index of ``queue`` and return the server reply."""
        index_name = f"{prefix}:{queue}"
        return await r.ft(index_name).dropindex()

    @staticmethod
    def add(
        r: "redis.asyncio.Redis",
        queue: str,
        *,
        prefix: str,
        payload: bytes = None,
        job_id: str = None,
        ttl: float = None,
    ) -> Coroutine:
        """Queue a new job in the ``created`` state.

        Args:
            r: async Redis client.
            queue: queue name.
            prefix: key prefix of the queue.
            payload: optional binary payload, stored under
                ``{entry_key}.payload``; skipped when falsy.
            job_id: explicit job id; a random UUID4 string when omitted.
            ttl: stored in the job document's ``ttl`` field.

        Returns:
            The result of ``pipeline.execute()``, to be awaited by the caller.
        """
        from uuid import uuid4

        job = Job(
            created_ts=time(),
            status="created",
            ttl=ttl,
        )
        if job_id is None:
            job_id = str(uuid4())

        entry_key = f"{prefix}:{queue}:{job_id}"
        p = r.pipeline()
        # vars(job) only holds instance attributes, so grab_ts is not written.
        p.json().set(entry_key, ".", vars(job))
        if payload:
            p.set(entry_key + ".payload", payload)
        return p.execute()

    @staticmethod
    async def take(r: "redis.asyncio.Redis", queue, *, prefix) -> Tuple[str, Any]:
        """Grab one ``created`` job and mark it ``in_progress``.

        Returns:
            ``(job_id, payload)`` for the job that was taken, or
            ``(None, None)`` when the search finds no ``created`` job.
        """
        index_name = f"{prefix}:{queue}"
        # note: search ranks results via TF-IDF. Use aggregation to sort by created_ts
        q = Query("@status: { created }").paging(0, 1)
        result = await r.ft(index_name).search(q)
        if not result.total:
            return None, None

        job = result.docs[0]
        # NOTE(review): the search and the status update are separate round
        # trips, so two concurrent workers may take the same job — confirm
        # whether callers need stronger guarantees.
        p = r.pipeline()
        p.get(job.id + ".payload")
        p.json().set(job.id, "$.status", "in_progress")
        p.json().set(job.id, "$.grab_ts", time())
        payload, *_ = await p.execute()

        # Strip the "{index_name}:" prefix to recover the bare job id.
        job_id = job.id[len(index_name) + 1 :]
        return job_id, payload

    @staticmethod
    def remove(r: "redis.asyncio.Redis", job_id, queue, *, prefix) -> Coroutine:
        """Delete a job document and its payload key.

        Returns:
            The result of ``pipeline.execute()``, to be awaited by the caller.
        """
        entry_name = f"{prefix}:{queue}:{job_id}"
        p = r.pipeline()
        p.json().delete(entry_name)
        p.delete(entry_name + ".payload")
        return p.execute()

    @staticmethod
    def reset(r: "redis.asyncio.Redis", job_id, queue, *, prefix):
        """Put a job back into the ``created`` state and null its ``grab_ts``.

        Returns:
            The result of ``pipeline.execute()``, to be awaited by the caller.
        """
        entry_name = f"{prefix}:{queue}:{job_id}"
        p = r.pipeline()
        p.json().set(entry_name, "$.status", "created")
        p.json().set(entry_name, "$.grab_ts", None)
        return p.execute()

    @staticmethod
    async def timeout(r: "redis.asyncio.Redis", queue, *, prefix, ttl=None):
        """Return stale ``in_progress`` jobs to the ``created`` state.

        Args:
            r: async Redis client.
            queue: queue name.
            prefix: key prefix of the queue.
            ttl: when truthy, only jobs grabbed more than ``ttl`` seconds ago
                are reset; otherwise every ``in_progress`` job found is reset.

        Returns:
            The replies of the pipelined updates.
        """
        index_name = f"{prefix}:{queue}"
        query = "@status: { in_progress }"
        if ttl:
            # Numeric range filter: grab_ts strictly below the cut-off.
            cutoff = time() - ttl
            query += f" @grab_ts:[-inf ({cutoff}]"

        # NOTE(review): without explicit paging, search() returns only the
        # first page of matches (presumably 10 documents) — confirm whether
        # callers rely on repeated invocations to drain larger backlogs.
        result = await r.ft(index_name).search(query)

        p = r.pipeline()
        for doc in result.docs:
            p.json().set(doc.id, "$.status", "created")
            p.json().delete(doc.id, "$.grab_ts")
        return await p.execute()