81 lines
1.8 KiB
Python
81 lines
1.8 KiB
Python
|
|
"""Redis cache wrapper with namespaced keys."""
|
||
|
|
import hashlib
|
||
|
|
import json
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
import redis.asyncio as redis_async
|
||
|
|
|
||
|
|
from app.config import settings
|
||
|
|
|
||
|
|
# Module-wide client shared by every helper below. It is None until
# init_redis() assigns it, and close_redis() sets it back to None.
_redis: redis_async.Redis | None = None
|
||
|
|
# Prefix put in front of every key built by the key_* helpers; invalidate_all()
# scans for keys matching this prefix.
NAMESPACE = "search:"
|
||
|
|
|
||
|
|
|
||
|
|
async def init_redis() -> None:
    """Build the shared Redis client from ``settings`` and ping it once.

    The client is stored in the module-level ``_redis`` before the ping is
    awaited, so if the ping raises, the exception propagates to the caller
    while the un-verified client stays installed.
    """
    global _redis
    client = redis_async.Redis(
        host=settings.redis_host,
        port=settings.redis_port,
        db=settings.redis_db,
        decode_responses=True,
        socket_timeout=2,
    )
    # Publish first, verify second (same order as before).
    _redis = client
    await client.ping()
|
||
|
|
|
||
|
|
|
||
|
|
async def close_redis() -> None:
    """Close the shared Redis client, if there is one, and forget it.

    The module-level reference is cleared before the close is awaited, so
    ``get_redis()`` never hands out a client that is closing or that failed
    to close, and retrying after a failed close cannot close the same client
    twice. Calling this when no client is stored is a no-op.

    Raises:
        Any exception raised by the client's ``close()``; the module-level
        reference has already been cleared by then.
    """
    global _redis
    # Detach first: even if close() raises, the module is left "uninitialised"
    # rather than pointing at a half-closed client.
    client, _redis = _redis, None
    if client is not None:
        await client.close()
|
||
|
|
|
||
|
|
|
||
|
|
def get_redis() -> redis_async.Redis:
    """Return the shared Redis client.

    Raises:
        RuntimeError: If no client is currently stored in ``_redis``.
    """
    client = _redis
    if client is not None:
        return client
    raise RuntimeError("Redis not initialized")
|
||
|
|
|
||
|
|
|
||
|
|
def _hash(value: str) -> str:
|
||
|
|
return hashlib.sha1(value.encode("utf-8")).hexdigest()[:16]
|
||
|
|
|
||
|
|
|
||
|
|
def key_query(query: str) -> str:
    """Build the cache key for a search query.

    The query is lower-cased before hashing, so queries that differ only in
    letter case map to the same key.
    """
    digest = _hash(query.lower())
    return NAMESPACE + "q:" + digest
|
||
|
|
|
||
|
|
|
||
|
|
def key_suggest(prefix: str) -> str:
    """Build the cache key for suggestions on a typed prefix.

    The prefix is lower-cased before hashing, so prefixes that differ only in
    letter case map to the same key.
    """
    digest = _hash(prefix.lower())
    return NAMESPACE + "suggest:" + digest
|
||
|
|
|
||
|
|
|
||
|
|
def key_empty(query: str) -> str:
    """Build the key under the ``empty:`` sub-namespace for a query.

    The query is lower-cased before hashing, so queries that differ only in
    letter case map to the same key.
    """
    digest = _hash(query.lower())
    return NAMESPACE + "empty:" + digest
|
||
|
|
|
||
|
|
|
||
|
|
def key_top_brands() -> str:
    """Return the fixed, namespaced ``top:brands`` key."""
    return NAMESPACE + "top:brands"
|
||
|
|
|
||
|
|
|
||
|
|
async def get_json(key: str) -> Any | None:
    """Fetch *key* from Redis and decode its value as JSON.

    Returns:
        The decoded value, or ``None`` when the key is absent or its stored
        value cannot be decoded as JSON.

    Raises:
        RuntimeError: From ``get_redis()`` when no client is initialised.
    """
    raw = await get_redis().get(key)
    if raw is not None:
        try:
            decoded = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            # An undecodable entry is reported the same way as a miss.
            decoded = None
        return decoded
    return None
|
||
|
|
|
||
|
|
|
||
|
|
async def set_json(key: str, value: Any, ttl: int) -> None:
    """Serialise *value* as JSON and store it under *key*.

    *ttl* is passed to the client's ``set`` as ``ex``. Objects that ``json``
    cannot encode natively are converted with ``str`` (``default=str``).

    Raises:
        RuntimeError: From ``get_redis()`` when no client is initialised.
    """
    # Resolve the client before serialising, matching the original
    # evaluation order.
    client = get_redis()
    payload = json.dumps(value, default=str)
    await client.set(key, payload, ex=ttl)
|
||
|
|
|
||
|
|
|
||
|
|
async def invalidate_all() -> int:
    """Delete every key in our namespace and return how many were removed.

    Keys are discovered with an incremental scan for ``NAMESPACE*`` and
    deleted in batches, so a large namespace costs one delete round trip per
    batch instead of one per key.

    The total is the sum of the delete replies (keys actually removed), not
    the number of keys the scan yielded: a scan may report the same key more
    than once, and a key can expire between being scanned and being deleted,
    so counting scanned keys can over-report.

    Raises:
        RuntimeError: From ``get_redis()`` when no client is initialised.
    """
    batch_size = 200  # also used as the scan's per-iteration count hint
    client = get_redis()
    deleted = 0
    batch = []
    async for key in client.scan_iter(match=f"{NAMESPACE}*", count=batch_size):
        batch.append(key)
        if len(batch) >= batch_size:
            deleted += await client.delete(*batch)
            batch.clear()
    if batch:
        # Flush whatever is left after the scan finishes.
        deleted += await client.delete(*batch)
    return deleted
|