Refactored ZMQ context implementation #16
@@ -1,11 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import zmq
|
||||||
from spade.agent import Agent
|
from spade.agent import Agent
|
||||||
from spade.behaviour import CyclicBehaviour
|
from spade.behaviour import CyclicBehaviour
|
||||||
import zmq
|
from zmq.asyncio import Context
|
||||||
|
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.core.zmq_context import context
|
|
||||||
from control_backend.schemas.ri_message import SpeechCommand
|
from control_backend.schemas.ri_message import SpeechCommand
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -55,6 +56,8 @@ class RICommandAgent(Agent):
|
|||||||
"""
|
"""
|
||||||
logger.info("Setting up %s", self.jid)
|
logger.info("Setting up %s", self.jid)
|
||||||
|
|
||||||
|
context = Context.instance()
|
||||||
|
|
||||||
# To the robot
|
# To the robot
|
||||||
self.pubsocket = context.socket(zmq.PUB)
|
self.pubsocket = context.socket(zmq.PUB)
|
||||||
if self.bind:
|
if self.bind:
|
||||||
@@ -64,7 +67,7 @@ class RICommandAgent(Agent):
|
|||||||
|
|
||||||
# Receive internal topics regarding commands
|
# Receive internal topics regarding commands
|
||||||
self.subsocket = context.socket(zmq.SUB)
|
self.subsocket = context.socket(zmq.SUB)
|
||||||
self.subsocket.connect(settings.zmq_settings.internal_comm_address)
|
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
|
||||||
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
|
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
|
||||||
|
|
||||||
# Add behaviour to our agent
|
# Add behaviour to our agent
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import zmq
|
||||||
from spade.agent import Agent
|
from spade.agent import Agent
|
||||||
from spade.behaviour import CyclicBehaviour
|
from spade.behaviour import CyclicBehaviour
|
||||||
import zmq
|
from zmq.asyncio import Context
|
||||||
|
|
||||||
from control_backend.core.config import settings
|
|
||||||
from control_backend.core.zmq_context import context
|
|
||||||
from control_backend.schemas.message import Message
|
|
||||||
from control_backend.agents.ri_command_agent import RICommandAgent
|
from control_backend.agents.ri_command_agent import RICommandAgent
|
||||||
|
from control_backend.core.config import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -47,7 +46,7 @@ class RICommunicationAgent(Agent):
|
|||||||
message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0)
|
message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0)
|
||||||
|
|
||||||
# We didnt get a reply :(
|
# We didnt get a reply :(
|
||||||
except asyncio.TimeoutError as e:
|
except TimeoutError:
|
||||||
logger.info("No ping retrieved in 3 seconds, killing myself.")
|
logger.info("No ping retrieved in 3 seconds, killing myself.")
|
||||||
self.kill()
|
self.kill()
|
||||||
|
|
||||||
@@ -75,7 +74,7 @@ class RICommunicationAgent(Agent):
|
|||||||
# Let's try a certain amount of times before failing connection
|
# Let's try a certain amount of times before failing connection
|
||||||
while retries < max_retries:
|
while retries < max_retries:
|
||||||
# Bind request socket
|
# Bind request socket
|
||||||
self.req_socket = context.socket(zmq.REQ)
|
self.req_socket = Context.instance().socket(zmq.REQ)
|
||||||
if self._bind:
|
if self._bind:
|
||||||
self.req_socket.bind(self._address)
|
self.req_socket.bind(self._address)
|
||||||
else:
|
else:
|
||||||
@@ -88,7 +87,7 @@ class RICommunicationAgent(Agent):
|
|||||||
try:
|
try:
|
||||||
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
|
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except TimeoutError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"No connection established in 20 seconds (attempt %d/%d)",
|
"No connection established in 20 seconds (attempt %d/%d)",
|
||||||
retries + 1,
|
retries + 1,
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from spade.message import Message
|
|||||||
|
|
||||||
from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer
|
from control_backend.agents.transcription.speech_recognizer import SpeechRecognizer
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.core.zmq_context import context as zmq_context
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -47,7 +46,8 @@ class TranscriptionAgent(Agent):
|
|||||||
"""Share a transcription to the other agents that depend on it."""
|
"""Share a transcription to the other agents that depend on it."""
|
||||||
receiver_jids = [
|
receiver_jids = [
|
||||||
settings.agent_settings.text_belief_extractor_agent_name
|
settings.agent_settings.text_belief_extractor_agent_name
|
||||||
+ '@' + settings.agent_settings.host,
|
+ "@"
|
||||||
|
+ settings.agent_settings.host,
|
||||||
] # Set message receivers here
|
] # Set message receivers here
|
||||||
|
|
||||||
for receiver_jid in receiver_jids:
|
for receiver_jid in receiver_jids:
|
||||||
@@ -68,7 +68,7 @@ class TranscriptionAgent(Agent):
|
|||||||
return await super().stop()
|
return await super().stop()
|
||||||
|
|
||||||
def _connect_audio_in_socket(self):
|
def _connect_audio_in_socket(self):
|
||||||
self.audio_in_socket = zmq_context.socket(zmq.SUB)
|
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
|
||||||
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||||
self.audio_in_socket.connect(self.audio_in_address)
|
self.audio_in_socket.connect(self.audio_in_address)
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from spade.behaviour import CyclicBehaviour
|
|||||||
|
|
||||||
from control_backend.agents.transcription import TranscriptionAgent
|
from control_backend.agents.transcription import TranscriptionAgent
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.core.zmq_context import context as zmq_context
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -121,7 +120,7 @@ class VADAgent(Agent):
|
|||||||
return await super().stop()
|
return await super().stop()
|
||||||
|
|
||||||
def _connect_audio_in_socket(self):
|
def _connect_audio_in_socket(self):
|
||||||
self.audio_in_socket = zmq_context.socket(zmq.SUB)
|
self.audio_in_socket = azmq.Context.instance().socket(zmq.SUB)
|
||||||
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
self.audio_in_socket.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||||
if self.audio_in_bind:
|
if self.audio_in_bind:
|
||||||
self.audio_in_socket.bind(self.audio_in_address)
|
self.audio_in_socket.bind(self.audio_in_address)
|
||||||
@@ -132,7 +131,7 @@ class VADAgent(Agent):
|
|||||||
def _connect_audio_out_socket(self) -> int | None:
|
def _connect_audio_out_socket(self) -> int | None:
|
||||||
"""Returns the port bound, or None if binding failed."""
|
"""Returns the port bound, or None if binding failed."""
|
||||||
try:
|
try:
|
||||||
self.audio_out_socket = zmq_context.socket(zmq.PUB)
|
self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB)
|
||||||
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
|
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
|
||||||
except zmq.ZMQBindError:
|
except zmq.ZMQBindError:
|
||||||
logger.error("Failed to bind an audio output socket after 100 tries.")
|
logger.error("Failed to bind an audio output socket after 100 tries.")
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
from fastapi import APIRouter, Request
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from zmq import Socket
|
import zmq
|
||||||
|
from fastapi import APIRouter, Request
|
||||||
|
from zmq.asyncio import Context
|
||||||
|
|
||||||
from control_backend.schemas.ri_message import SpeechCommand, RIEndpoint
|
from control_backend.core.config import settings
|
||||||
|
from control_backend.schemas.ri_message import SpeechCommand
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -15,8 +17,8 @@ async def receive_command(command: SpeechCommand, request: Request):
|
|||||||
# Validate and retrieve data.
|
# Validate and retrieve data.
|
||||||
SpeechCommand.model_validate(command)
|
SpeechCommand.model_validate(command)
|
||||||
topic = b"command"
|
topic = b"command"
|
||||||
pub_socket: Socket = request.app.state.internal_comm_socket
|
pub_socket = Context.instance().socket(zmq.PUB)
|
||||||
pub_socket.send_multipart([topic, command.model_dump_json().encode()])
|
pub_socket.connect(settings.zmq_settings.internal_pub_address)
|
||||||
|
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
|
||||||
|
|
||||||
return {"status": "Command received"}
|
return {"status": "Command received"}
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import zmq
|
||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
from zmq import Socket
|
from zmq.asyncio import Context
|
||||||
|
|
||||||
|
from control_backend.core.config import settings
|
||||||
from control_backend.schemas.message import Message
|
from control_backend.schemas.message import Message
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -17,8 +19,8 @@ async def receive_message(message: Message, request: Request):
|
|||||||
topic = b"message"
|
topic = b"message"
|
||||||
body = message.model_dump_json().encode("utf-8")
|
body = message.model_dump_json().encode("utf-8")
|
||||||
|
|
||||||
pub_socket: Socket = request.app.state.internal_comm_socket
|
pub_socket = Context.instance().socket(zmq.PUB)
|
||||||
|
pub_socket.bind(settings.zmq_settings.internal_pub_address)
|
||||||
pub_socket.send_multipart([topic, body])
|
await pub_socket.send_multipart([topic, body])
|
||||||
|
|
||||||
return {"status": "Message received"}
|
return {"status": "Message received"}
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||||||
|
|
||||||
|
|
||||||
class ZMQSettings(BaseModel):
|
class ZMQSettings(BaseModel):
|
||||||
internal_comm_address: str = "tcp://localhost:5560"
|
internal_pub_address: str = "tcp://localhost:5560"
|
||||||
|
internal_sub_address: str = "tcp://localhost:5561"
|
||||||
|
|
||||||
|
|
||||||
class AgentSettings(BaseModel):
|
class AgentSettings(BaseModel):
|
||||||
@@ -24,6 +25,7 @@ class LLMSettings(BaseModel):
|
|||||||
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
|
local_llm_url: str = "http://localhost:1234/v1/chat/completions"
|
||||||
local_llm_model: str = "openai/gpt-oss-20b"
|
local_llm_model: str = "openai/gpt-oss-20b"
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
app_title: str = "PepperPlus"
|
app_title: str = "PepperPlus"
|
||||||
|
|
||||||
@@ -37,4 +39,5 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env")
|
model_config = SettingsConfigDict(env_file=".env")
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
from zmq.asyncio import Context
|
|
||||||
|
|
||||||
context = Context()
|
|
||||||
@@ -7,17 +7,18 @@ import logging
|
|||||||
import zmq
|
import zmq
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from zmq.asyncio import Context
|
||||||
|
|
||||||
|
from control_backend.agents.bdi.bdi_core import BDICoreAgent
|
||||||
|
from control_backend.agents.bdi.text_extractor import TBeliefExtractor
|
||||||
|
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
|
||||||
|
from control_backend.agents.llm.llm import LLMAgent
|
||||||
|
|
||||||
# Internal imports
|
# Internal imports
|
||||||
from control_backend.agents.ri_communication_agent import RICommunicationAgent
|
from control_backend.agents.ri_communication_agent import RICommunicationAgent
|
||||||
from control_backend.agents.bdi.bdi_core import BDICoreAgent
|
|
||||||
from control_backend.agents.vad_agent import VADAgent
|
from control_backend.agents.vad_agent import VADAgent
|
||||||
from control_backend.agents.llm.llm import LLMAgent
|
|
||||||
from control_backend.agents.bdi.text_extractor import TBeliefExtractor
|
|
||||||
from control_backend.agents.belief_collector.belief_collector import BeliefCollectorAgent
|
|
||||||
from control_backend.api.v1.router import api_router
|
from control_backend.api.v1.router import api_router
|
||||||
from control_backend.core.config import settings
|
from control_backend.core.config import settings
|
||||||
from control_backend.core.zmq_context import context
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
@@ -28,12 +29,17 @@ async def lifespan(app: FastAPI):
|
|||||||
logger.info("%s starting up.", app.title)
|
logger.info("%s starting up.", app.title)
|
||||||
|
|
||||||
# Initiate sockets
|
# Initiate sockets
|
||||||
internal_comm_socket = context.socket(zmq.PUB)
|
context = Context.instance()
|
||||||
internal_comm_address = settings.zmq_settings.internal_comm_address
|
|
||||||
internal_comm_socket.bind(internal_comm_address)
|
|
||||||
app.state.internal_comm_socket = internal_comm_socket
|
|
||||||
logger.info("Internal publishing socket bound to %s", internal_comm_socket)
|
|
||||||
|
|
||||||
|
internal_pub_socket = context.socket(zmq.XPUB)
|
||||||
|
internal_pub_socket.bind(settings.zmq_settings.internal_pub_address)
|
||||||
|
logger.debug("Internal publishing socket bound to %s", internal_pub_socket)
|
||||||
|
|
||||||
|
internal_sub_socket = context.socket(zmq.XSUB)
|
||||||
|
internal_sub_socket.bind(settings.zmq_settings.internal_sub_address)
|
||||||
|
logger.debug("Internal subscribing socket bound to %s", internal_sub_socket)
|
||||||
|
|
||||||
|
zmq.proxy(internal_pub_socket, internal_sub_socket)
|
||||||
|
|
||||||
# Initiate agents
|
# Initiate agents
|
||||||
ri_communication_agent = RICommunicationAgent(
|
ri_communication_agent = RICommunicationAgent(
|
||||||
@@ -45,26 +51,28 @@ async def lifespan(app: FastAPI):
|
|||||||
await ri_communication_agent.start()
|
await ri_communication_agent.start()
|
||||||
|
|
||||||
llm_agent = LLMAgent(
|
llm_agent = LLMAgent(
|
||||||
settings.agent_settings.llm_agent_name + '@' + settings.agent_settings.host,
|
settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
|
||||||
settings.agent_settings.llm_agent_name,
|
settings.agent_settings.llm_agent_name,
|
||||||
)
|
)
|
||||||
await llm_agent.start()
|
await llm_agent.start()
|
||||||
|
|
||||||
bdi_core = BDICoreAgent(
|
bdi_core = BDICoreAgent(
|
||||||
settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host,
|
settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
|
||||||
settings.agent_settings.bdi_core_agent_name,
|
settings.agent_settings.bdi_core_agent_name,
|
||||||
"src/control_backend/agents/bdi/rules.asl",
|
"src/control_backend/agents/bdi/rules.asl",
|
||||||
)
|
)
|
||||||
await bdi_core.start()
|
await bdi_core.start()
|
||||||
|
|
||||||
belief_collector = BeliefCollectorAgent(
|
belief_collector = BeliefCollectorAgent(
|
||||||
settings.agent_settings.belief_collector_agent_name + '@' + settings.agent_settings.host,
|
settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host,
|
||||||
settings.agent_settings.belief_collector_agent_name,
|
settings.agent_settings.belief_collector_agent_name,
|
||||||
)
|
)
|
||||||
await belief_collector.start()
|
await belief_collector.start()
|
||||||
|
|
||||||
text_belief_extractor = TBeliefExtractor(
|
text_belief_extractor = TBeliefExtractor(
|
||||||
settings.agent_settings.text_belief_extractor_agent_name + '@' + settings.agent_settings.host,
|
settings.agent_settings.text_belief_extractor_agent_name
|
||||||
|
+ "@"
|
||||||
|
+ settings.agent_settings.host,
|
||||||
settings.agent_settings.text_belief_extractor_agent_name,
|
settings.agent_settings.text_belief_extractor_agent_name,
|
||||||
)
|
)
|
||||||
await text_belief_extractor.start()
|
await text_belief_extractor.start()
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
from control_backend.api.v1.endpoints import command
|
from control_backend.api.v1.endpoints import command
|
||||||
from control_backend.schemas.ri_message import SpeechCommand
|
from control_backend.schemas.ri_message import SpeechCommand
|
||||||
@@ -15,7 +16,6 @@ def app():
|
|||||||
"""
|
"""
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(command.router)
|
app.include_router(command.router)
|
||||||
app.state.internal_comm_socket = MagicMock() # mock ZMQ socket
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
@@ -25,12 +25,42 @@ def client(app):
|
|||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
def test_receive_command_endpoint(client, app):
|
@pytest.mark.asyncio
|
||||||
|
@patch("control_backend.api.endpoints.command.Context.instance")
|
||||||
|
async def test_receive_command_success(mock_context_instance, async_client):
|
||||||
|
"""
|
||||||
|
Test for successful reception of a command.
|
||||||
|
Ensures the status code is 202 and the response body is correct.
|
||||||
|
It also verifies that the ZeroMQ socket's send_multipart method is called with the expected data.
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
mock_pub_socket = AsyncMock()
|
||||||
|
mock_context_instance.return_value.socket.return_value = mock_pub_socket
|
||||||
|
|
||||||
|
command_data = {"command": "test_command", "text": "This is a test"}
|
||||||
|
speech_command = SpeechCommand(**command_data)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = await async_client.post("/command", json=command_data)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 202
|
||||||
|
assert response.json() == {"status": "Command received"}
|
||||||
|
|
||||||
|
# Verify that the ZMQ socket was used correctly
|
||||||
|
mock_context_instance.return_value.socket.assert_called_once_with(1) # zmq.PUB is 1
|
||||||
|
mock_pub_socket.connect.assert_called_once()
|
||||||
|
mock_pub_socket.send_multipart.assert_awaited_once_with(
|
||||||
|
[b"command", speech_command.model_dump_json().encode()]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_receive_command_endpoint(client, app, mocker):
|
||||||
"""
|
"""
|
||||||
Test that a POST to /command sends the right multipart message
|
Test that a POST to /command sends the right multipart message
|
||||||
and returns a 202 with the expected JSON body.
|
and returns a 202 with the expected JSON body.
|
||||||
"""
|
"""
|
||||||
mock_socket = app.state.internal_comm_socket
|
mock_socket = mocker.patch.object()
|
||||||
|
|
||||||
# Prepare test payload that matches SpeechCommand
|
# Prepare test payload that matches SpeechCommand
|
||||||
payload = {"endpoint": "actuate/speech", "data": "yooo"}
|
payload = {"endpoint": "actuate/speech", "data": "yooo"}
|
||||||
|
|||||||
Reference in New Issue
Block a user