Add /recall endpoint, add Redis support
This commit is contained in:
		
							parent
							
								
									3efaddfae1
								
							
						
					
					
						commit
						b7c26536d1
					
				
							
								
								
									
										74
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										74
									
								
								main.py
									
									
									
									
									
								
							| @ -3,6 +3,29 @@ from pydantic import BaseModel | |||||||
| from typing import Any | from typing import Any | ||||||
| import uvicorn | import uvicorn | ||||||
| import json | import json | ||||||
|  | import redis | ||||||
|  | import os | ||||||
|  | 
 | ||||||
|  | # Initialize Redis connection | ||||||
|  | def get_redis_client(): | ||||||
|  |     redis_host = os.getenv('REDIS_HOST', 'localhost') | ||||||
|  |     redis_port = int(os.getenv('REDIS_PORT', 6379)) | ||||||
|  |     redis_db = int(os.getenv('REDIS_DB', 0)) | ||||||
|  | 
 | ||||||
|  |     try: | ||||||
|  |         client = redis.StrictRedis( | ||||||
|  |             host=redis_host, | ||||||
|  |             port=redis_port, | ||||||
|  |             db=redis_db, | ||||||
|  |             decode_responses=True | ||||||
|  |         ) | ||||||
|  |         # Test connection | ||||||
|  |         client.ping() | ||||||
|  |         return client | ||||||
|  |     except redis.ConnectionError as e: | ||||||
|  |         raise HTTPException(status_code=503, detail=f"Could not connect to Redis: {str(e)}") | ||||||
|  | 
 | ||||||
|  | redis_client = get_redis_client() | ||||||
| 
 | 
 | ||||||
| app = FastAPI( | app = FastAPI( | ||||||
|     title="Remember Service", |     title="Remember Service", | ||||||
| @ -17,20 +40,63 @@ class RememberRequest(BaseModel): | |||||||
|     content: Any |     content: Any | ||||||
|     metadata: dict = {} |     metadata: dict = {} | ||||||
| 
 | 
 | ||||||
| @app.post("/remember",  | @app.post("/remember", | ||||||
|           summary="Store information to remember", |           summary="Store information to remember", | ||||||
|           description="Saves the provided content and optional metadata. This endpoint is designed to be called by an LLM as a tool.") |           description="Saves the provided content and optional metadata in Redis.") | ||||||
| def remember_endpoint(request: RememberRequest): | def remember_endpoint(request: RememberRequest): | ||||||
|     """ |     """ | ||||||
|     Endpoint to remember information. |     Endpoint to remember information. | ||||||
|     Prints the received payload to stdout in JSON format. |     Stores the received payload in Redis and prints it to stdout in JSON format. | ||||||
|     """ |     """ | ||||||
|     payload = { |     payload = { | ||||||
|         "content": request.content, |         "content": request.content, | ||||||
|         "metadata": request.metadata |         "metadata": request.metadata | ||||||
|     } |     } | ||||||
|  |     payload_json = json.dumps(payload, ensure_ascii=False) | ||||||
|  | 
 | ||||||
|  |     # Generate a unique key based on timestamp | ||||||
|  |     import time | ||||||
|  |     key = f"remember:{int(time.time())}" | ||||||
|  | 
 | ||||||
|  |     # Store in Redis | ||||||
|  |     redis_client.set(key, payload_json) | ||||||
|  | 
 | ||||||
|     print(json.dumps(payload, indent=2, ensure_ascii=False)) |     print(json.dumps(payload, indent=2, ensure_ascii=False)) | ||||||
|     return {"status": "remembered"} |     return {"status": "remembered", "key": key} | ||||||
|  | 
 | ||||||
|  | @app.get("/recall", | ||||||
|  |          summary="Recall information previously remembered", | ||||||
|  |          description="Searches for content in Redis based on query string. Returns matching entries.") | ||||||
|  | def recall_endpoint(query: str): | ||||||
|  |     """ | ||||||
|  |     Endpoint to recall information from Redis. | ||||||
|  |     Searches for entries containing the query string in their content or metadata. | ||||||
|  |     """ | ||||||
|  |     try: | ||||||
|  |         # Get all keys matching the remember pattern | ||||||
|  |         keys = redis_client.keys("remember:*") | ||||||
|  | 
 | ||||||
|  |         results = [] | ||||||
|  |         for key in keys: | ||||||
|  |             data = redis_client.get(key) | ||||||
|  |             if data: | ||||||
|  |                 payload = json.loads(data) | ||||||
|  |                 # Check if query is in content or metadata | ||||||
|  |                 if (isinstance(payload.get("content"), str) and query.lower() in payload["content"].lower()) or \ | ||||||
|  |                    (isinstance(payload.get("metadata"), dict) and any( | ||||||
|  |                        isinstance(v, str) and query.lower() in v.lower() | ||||||
|  |                        for v in payload["metadata"].values() | ||||||
|  |                    )): | ||||||
|  |                     results.append({ | ||||||
|  |                         "key": key, | ||||||
|  |                         "content": payload["content"], | ||||||
|  |                         "metadata": payload["metadata"] | ||||||
|  |                     }) | ||||||
|  | 
 | ||||||
|  |         return {"results": results, "count": len(results)} | ||||||
|  | 
 | ||||||
|  |     except redis.RedisError as e: | ||||||
|  |         raise HTTPException(status_code=500, detail=f"Error searching Redis: {str(e)}") | ||||||
| 
 | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     uvicorn.run(app, host="0.0.0.0", port=8000) |     uvicorn.run(app, host="0.0.0.0", port=8000) | ||||||
|  | |||||||
| @ -1,3 +1,4 @@ | |||||||
| fastapi>=0.68.0 | fastapi>=0.68.0 | ||||||
| uvicorn>=0.15.0 | uvicorn>=0.15.0 | ||||||
| pydantic>=1.8.0 | pydantic>=1.8.0 | ||||||
|  | redis>=4.0.0 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user