Add /recall endpoint, add Redis support
This commit is contained in:
parent
3efaddfae1
commit
b7c26536d1
72
main.py
72
main.py
@ -3,6 +3,29 @@ from pydantic import BaseModel
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import json
|
import json
|
||||||
|
import redis
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Initialize Redis connection
|
||||||
|
def get_redis_client():
|
||||||
|
redis_host = os.getenv('REDIS_HOST', 'localhost')
|
||||||
|
redis_port = int(os.getenv('REDIS_PORT', 6379))
|
||||||
|
redis_db = int(os.getenv('REDIS_DB', 0))
|
||||||
|
|
||||||
|
try:
|
||||||
|
client = redis.StrictRedis(
|
||||||
|
host=redis_host,
|
||||||
|
port=redis_port,
|
||||||
|
db=redis_db,
|
||||||
|
decode_responses=True
|
||||||
|
)
|
||||||
|
# Test connection
|
||||||
|
client.ping()
|
||||||
|
return client
|
||||||
|
except redis.ConnectionError as e:
|
||||||
|
raise HTTPException(status_code=503, detail=f"Could not connect to Redis: {str(e)}")
|
||||||
|
|
||||||
|
redis_client = get_redis_client()
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Remember Service",
|
title="Remember Service",
|
||||||
@ -19,18 +42,61 @@ class RememberRequest(BaseModel):
|
|||||||
|
|
||||||
@app.post("/remember",
|
@app.post("/remember",
|
||||||
summary="Store information to remember",
|
summary="Store information to remember",
|
||||||
description="Saves the provided content and optional metadata. This endpoint is designed to be called by an LLM as a tool.")
|
description="Saves the provided content and optional metadata in Redis.")
|
||||||
def remember_endpoint(request: RememberRequest):
|
def remember_endpoint(request: RememberRequest):
|
||||||
"""
|
"""
|
||||||
Endpoint to remember information.
|
Endpoint to remember information.
|
||||||
Prints the received payload to stdout in JSON format.
|
Stores the received payload in Redis and prints it to stdout in JSON format.
|
||||||
"""
|
"""
|
||||||
payload = {
|
payload = {
|
||||||
"content": request.content,
|
"content": request.content,
|
||||||
"metadata": request.metadata
|
"metadata": request.metadata
|
||||||
}
|
}
|
||||||
|
payload_json = json.dumps(payload, ensure_ascii=False)
|
||||||
|
|
||||||
|
# Generate a unique key based on timestamp
|
||||||
|
import time
|
||||||
|
key = f"remember:{int(time.time())}"
|
||||||
|
|
||||||
|
# Store in Redis
|
||||||
|
redis_client.set(key, payload_json)
|
||||||
|
|
||||||
print(json.dumps(payload, indent=2, ensure_ascii=False))
|
print(json.dumps(payload, indent=2, ensure_ascii=False))
|
||||||
return {"status": "remembered"}
|
return {"status": "remembered", "key": key}
|
||||||
|
|
||||||
|
@app.get("/recall",
|
||||||
|
summary="Recall information previously remembered",
|
||||||
|
description="Searches for content in Redis based on query string. Returns matching entries.")
|
||||||
|
def recall_endpoint(query: str):
|
||||||
|
"""
|
||||||
|
Endpoint to recall information from Redis.
|
||||||
|
Searches for entries containing the query string in their content or metadata.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Get all keys matching the remember pattern
|
||||||
|
keys = redis_client.keys("remember:*")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for key in keys:
|
||||||
|
data = redis_client.get(key)
|
||||||
|
if data:
|
||||||
|
payload = json.loads(data)
|
||||||
|
# Check if query is in content or metadata
|
||||||
|
if (isinstance(payload.get("content"), str) and query.lower() in payload["content"].lower()) or \
|
||||||
|
(isinstance(payload.get("metadata"), dict) and any(
|
||||||
|
isinstance(v, str) and query.lower() in v.lower()
|
||||||
|
for v in payload["metadata"].values()
|
||||||
|
)):
|
||||||
|
results.append({
|
||||||
|
"key": key,
|
||||||
|
"content": payload["content"],
|
||||||
|
"metadata": payload["metadata"]
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"results": results, "count": len(results)}
|
||||||
|
|
||||||
|
except redis.RedisError as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"Error searching Redis: {str(e)}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
fastapi>=0.68.0
|
fastapi>=0.68.0
|
||||||
uvicorn>=0.15.0
|
uvicorn>=0.15.0
|
||||||
pydantic>=1.8.0
|
pydantic>=1.8.0
|
||||||
|
redis>=4.0.0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user