diff --git a/src/control_backend/agents/test_agent.py b/src/control_backend/agents/test_agent.py index 7a9707b..749c96b 100644 --- a/src/control_backend/agents/test_agent.py +++ b/src/control_backend/agents/test_agent.py @@ -1,4 +1,37 @@ +import json +import logging from spade.agent import Agent +from spade.behaviour import CyclicBehaviour +import zmq + +from control_backend.core.config import settings +from control_backend.core.zmq_context import context +from control_backend.schemas.message import Message + +logger = logging.getLogger(__name__) class TestAgent(Agent): - pass \ No newline at end of file + socket: zmq.Socket + + class ListenBehaviour(CyclicBehaviour): + async def run(self): + assert self.agent is not None + topic, body = await self.agent.socket.recv_multipart() + + try: + message_json = json.loads(body.decode("utf-8")) + message = Message.model_validate(message_json) + logger.info("Received message \"%s\"", message.message) + except Exception as e: + logger.error("Error processing message: %s", e) + + async def setup(self): + logger.info("Setting up %s", self.jid) + self.socket = context.socket(zmq.SUB) + self.socket.connect(settings.zmq_settings.internal_comm_address) + self.socket.setsockopt(zmq.SUBSCRIBE, b"message") + + b = self.ListenBehaviour() + self.add_behaviour(b) + + logger.info("Finished setting up %s", self.jid) diff --git a/src/control_backend/api/v1/endpoints/message.py b/src/control_backend/api/v1/endpoints/message.py index 1ad0b65..fef07b8 100644 --- a/src/control_backend/api/v1/endpoints/message.py +++ b/src/control_backend/api/v1/endpoints/message.py @@ -1,13 +1,23 @@ from fastapi import APIRouter, Request import logging +from zmq import Socket + from control_backend.schemas.message import Message logger = logging.getLogger(__name__) router = APIRouter() -# TODO: implement -@router.post("/message") +@router.post("/message", status_code=202) async def receive_message(message: Message, request: Request): logger.info("Received message: %s", message.message) + + topic = b"message" + body = message.model_dump_json().encode("utf-8") + + pub_socket: Socket = request.app.state.internal_comm_socket + + pub_socket.send_multipart([topic, body]) + + return {"status": "Message received"} diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index 8d91af5..fca21b3 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -1,11 +1,16 @@ -from pydantic import HttpUrl +from pydantic import BaseModel from pydantic_settings import BaseSettings, SettingsConfigDict +class ZMQSettings(BaseModel): + internal_comm_address: str = "tcp://localhost:5560" + class Settings(BaseSettings): app_title: str = "PepperPlus" ui_url: str = "http://localhost:5173" + + zmq_settings: ZMQSettings = ZMQSettings() model_config = SettingsConfigDict(env_file=".env") -settings = Settings() \ No newline at end of file +settings = Settings() diff --git a/src/control_backend/core/zmq_context.py b/src/control_backend/core/zmq_context.py new file mode 100644 index 0000000..a74544f --- /dev/null +++ b/src/control_backend/core/zmq_context.py @@ -0,0 +1,3 @@ +from zmq.asyncio import Context + +context = Context() diff --git a/src/control_backend/main.py b/src/control_backend/main.py index 8fa0428..cd4d3fa 100644 --- a/src/control_backend/main.py +++ b/src/control_backend/main.py @@ -3,10 +3,13 @@ import contextlib from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware import logging +import zmq # Internal imports +from control_backend.agents.test_agent import TestAgent from control_backend.api.v1.router import api_router from control_backend.core.config import settings +from control_backend.core.zmq_context import context logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) @@ -14,6 +17,17 @@ logging.basicConfig(level=logging.INFO) @contextlib.asynccontextmanager async def lifespan(app: FastAPI): logger.info("%s starting up.", app.title) + + # Initiate sockets + internal_comm_socket = context.socket(zmq.PUB) + internal_comm_address = settings.zmq_settings.internal_comm_address + internal_comm_socket.bind(internal_comm_address) + app.state.internal_comm_socket = internal_comm_socket + logger.info("Internal publishing socket bound to %s", internal_comm_socket) + + # Initiate agents + test_agent = TestAgent("test_agent@localhost", "test_agent") + await test_agent.start() yield @@ -34,4 +48,4 @@ app.include_router(api_router, prefix="") # TODO: make prefix /api/v1 @app.get("/") async def root(): - return {"status": "ok"} \ No newline at end of file + return {"status": "ok"}