diff --git a/main.py b/main.py index 546286a..3710ea9 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,29 @@ from pydantic import BaseModel from typing import Any import uvicorn import json +import redis +import os + +# Initialize Redis connection +def get_redis_client(): + redis_host = os.getenv('REDIS_HOST', 'localhost') + redis_port = int(os.getenv('REDIS_PORT', 6379)) + redis_db = int(os.getenv('REDIS_DB', 0)) + + try: + client = redis.StrictRedis( + host=redis_host, + port=redis_port, + db=redis_db, + decode_responses=True + ) + # Test connection + client.ping() + return client + except redis.ConnectionError as e: + raise HTTPException(status_code=503, detail=f"Could not connect to Redis: {str(e)}") + +redis_client = get_redis_client() app = FastAPI( title="Remember Service", @@ -17,20 +40,63 @@ class RememberRequest(BaseModel): content: Any metadata: dict = {} -@app.post("/remember", +@app.post("/remember", summary="Store information to remember", - description="Saves the provided content and optional metadata. This endpoint is designed to be called by an LLM as a tool.") + description="Saves the provided content and optional metadata in Redis.") def remember_endpoint(request: RememberRequest): """ Endpoint to remember information. - Prints the received payload to stdout in JSON format. + Stores the received payload in Redis and prints it to stdout in JSON format. """ payload = { "content": request.content, "metadata": request.metadata } + payload_json = json.dumps(payload, ensure_ascii=False) + + # Generate a unique key based on timestamp + import time + key = f"remember:{int(time.time())}" + + # Store in Redis + redis_client.set(key, payload_json) + print(json.dumps(payload, indent=2, ensure_ascii=False)) - return {"status": "remembered"} + return {"status": "remembered", "key": key} + +@app.get("/recall", + summary="Recall information previously remembered", + description="Searches for content in Redis based on query string. Returns matching entries.") +def recall_endpoint(query: str): + """ + Endpoint to recall information from Redis. + Searches for entries containing the query string in their content or metadata. + """ + try: + # Get all keys matching the remember pattern + keys = redis_client.keys("remember:*") + + results = [] + for key in keys: + data = redis_client.get(key) + if data: + payload = json.loads(data) + # Check if query is in content or metadata + if (isinstance(payload.get("content"), str) and query.lower() in payload["content"].lower()) or \ + (isinstance(payload.get("metadata"), dict) and any( + isinstance(v, str) and query.lower() in v.lower() + for v in payload["metadata"].values() + )): + results.append({ + "key": key, + "content": payload["content"], + "metadata": payload["metadata"] + }) + + return {"results": results, "count": len(results)} + + except redis.RedisError as e: + raise HTTPException(status_code=500, detail=f"Error searching Redis: {str(e)}") if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/requirements.txt b/requirements.txt index 6a42f7b..af2a7c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ fastapi>=0.68.0 uvicorn>=0.15.0 pydantic>=1.8.0 +redis>=4.0.0