Refactoring agent and behaviour naming and structure. #25

Merged
9828273 merged 9 commits from refactor/agent-naming into dev 2025-11-19 15:37:14 +00:00
37 changed files with 355 additions and 376 deletions

View File

@@ -1,7 +1 @@
from .base import BaseAgent as BaseAgent from .base import BaseAgent as BaseAgent
from .belief_collector.belief_collector import BeliefCollectorAgent as BeliefCollectorAgent
from .llm.llm import LLMAgent as LLMAgent
from .ri_command_agent import RICommandAgent as RICommandAgent
from .ri_communication_agent import RICommunicationAgent as RICommunicationAgent
from .transcription.transcription_agent import TranscriptionAgent as TranscriptionAgent
from .vad_agent import VADAgent as VADAgent

View File

@@ -0,0 +1 @@
from .robot_speech_agent import RobotSpeechAgent as RobotSpeechAgent

View File

@@ -10,7 +10,7 @@ from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand
class RICommandAgent(BaseAgent): class RobotSpeechAgent(BaseAgent):
subsocket: zmq.Socket subsocket: zmq.Socket
pubsocket: zmq.Socket pubsocket: zmq.Socket
address = "" address = ""
@@ -29,7 +29,7 @@ class RICommandAgent(BaseAgent):
self.address = address self.address = address
self.bind = bind self.bind = bind
class SendCommandsBehaviour(CyclicBehaviour): class SendZMQCommandsBehaviour(CyclicBehaviour):
"""Behaviour for sending commands received from the UI.""" """Behaviour for sending commands received from the UI."""
async def run(self): async def run(self):
@@ -50,7 +50,7 @@ class RICommandAgent(BaseAgent):
except Exception as e: except Exception as e:
self.agent.logger.error("Error processing message: %s", e) self.agent.logger.error("Error processing message: %s", e)
class SendPythonCommandsBehaviour(CyclicBehaviour): class SendSpadeCommandsBehaviour(CyclicBehaviour):
"""Behaviour for sending commands received from other Python agents.""" """Behaviour for sending commands received from other Python agents."""
async def run(self): async def run(self):
@@ -64,7 +64,7 @@ class RICommandAgent(BaseAgent):
async def setup(self): async def setup(self):
""" """
Setup the command agent Setup the robot speech command agent
""" """
self.logger.info("Setting up %s", self.jid) self.logger.info("Setting up %s", self.jid)
@@ -83,8 +83,8 @@ class RICommandAgent(BaseAgent):
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command") self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
# Add behaviour to our agent # Add behaviour to our agent
commands_behaviour = self.SendCommandsBehaviour() commands_behaviour = self.SendZMQCommandsBehaviour()
self.add_behaviour(commands_behaviour) self.add_behaviour(commands_behaviour)
self.add_behaviour(self.SendPythonCommandsBehaviour()) self.add_behaviour(self.SendSpadeCommandsBehaviour())
self.logger.info("Finished setting up %s", self.jid) self.logger.info("Finished setting up %s", self.jid)

View File

@@ -1,2 +1,7 @@
from .bdi_core import BDICoreAgent as BDICoreAgent from .bdi_core_agent.bdi_core_agent import BDICoreAgent as BDICoreAgent
from .text_extractor import TBeliefExtractorAgent as TBeliefExtractorAgent from .belief_collector_agent.belief_collector_agent import (
BDIBeliefCollectorAgent as BDIBeliefCollectorAgent,
)
from .text_belief_extractor_agent.text_belief_extractor_agent import (
TextBeliefExtractorAgent as TextBeliefExtractorAgent,
)

View File

@@ -7,7 +7,7 @@ from spade_bdi.bdi import BDIAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from .behaviours.belief_setter import BeliefSetterBehaviour from .behaviours.belief_setter_behaviour import BeliefSetterBehaviour
from .behaviours.receive_llm_resp_behaviour import ReceiveLLMResponseBehaviour from .behaviours.receive_llm_resp_behaviour import ReceiveLLMResponseBehaviour
@@ -57,7 +57,7 @@ class BDICoreAgent(BDIAgent):
class SendBehaviour(OneShotBehaviour): class SendBehaviour(OneShotBehaviour):
async def run(self) -> None: async def run(self) -> None:
msg = Message( msg = Message(
to=settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host, to=settings.agent_settings.llm_name + "@" + settings.agent_settings.host,
body=text, body=text,
) )

View File

@@ -32,7 +32,7 @@ class BeliefSetterBehaviour(CyclicBehaviour):
self.agent.logger.debug("Processing message from sender: %s", sender) self.agent.logger.debug("Processing message from sender: %s", sender)
match sender: match sender:
case settings.agent_settings.belief_collector_agent_name: case settings.agent_settings.bdi_belief_collector_name:
self.agent.logger.debug( self.agent.logger.debug(
"Message is from the belief collector agent. Processing as belief message." "Message is from the belief collector agent. Processing as belief message."
) )

View File

@@ -15,14 +15,14 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
sender = msg.sender.node sender = msg.sender.node
match sender: match sender:
case settings.agent_settings.llm_agent_name: case settings.agent_settings.llm_name:
content = msg.body content = msg.body
self.agent.logger.info("Received LLM response: %s", content) self.agent.logger.info("Received LLM response: %s", content)
speech_command = SpeechCommand(data=content) speech_command = SpeechCommand(data=content)
message = Message( message = Message(
to=settings.agent_settings.ri_command_agent_name to=settings.agent_settings.robot_speech_name
+ "@" + "@"
+ settings.agent_settings.host, + settings.agent_settings.host,
sender=self.agent.jid, sender=self.agent.jid,

View File

@@ -7,7 +7,7 @@ from spade.behaviour import CyclicBehaviour
from control_backend.core.config import settings from control_backend.core.config import settings
class ContinuousBeliefCollector(CyclicBehaviour): class BeliefCollectorBehaviour(CyclicBehaviour):
""" """
Continuously collects beliefs/emotions from extractor agents: Continuously collects beliefs/emotions from extractor agents:
Then we send a unified belief packet to the BDI agent. Then we send a unified belief packet to the BDI agent.
@@ -35,7 +35,7 @@ class ContinuousBeliefCollector(CyclicBehaviour):
msg_type = payload.get("type") msg_type = payload.get("type")
# Prefer explicit 'type' field # Prefer explicit 'type' field
if msg_type == "belief_extraction_text" or sender_node == "belief_text_agent_mock": if msg_type == "belief_extraction_text" or sender_node == "bel_text_agent_mock":
self.agent.logger.debug( self.agent.logger.debug(
"Message routed to _handle_belief_text (sender=%s)", sender_node "Message routed to _handle_belief_text (sender=%s)", sender_node
) )
@@ -83,7 +83,7 @@ class ContinuousBeliefCollector(CyclicBehaviour):
if not beliefs: if not beliefs:
return return
to_jid = f"{settings.agent_settings.bdi_core_agent_name}@{settings.agent_settings.host}" to_jid = f"{settings.agent_settings.bdi_core_name}@{settings.agent_settings.host}"
msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs") msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs")
msg.body = json.dumps(beliefs) msg.body = json.dumps(beliefs)

View File

@@ -0,0 +1,11 @@
from control_backend.agents.base import BaseAgent
from .behaviours.belief_collector_behaviour import BeliefCollectorBehaviour
class BDIBeliefCollectorAgent(BaseAgent):
async def setup(self):
self.logger.info("BDIBeliefCollectorAgent starting (%s)", self.jid)
# Attach the continuous collector behaviour (listens and forwards to BDI)
self.add_behaviour(BeliefCollectorBehaviour())
self.logger.info("BDIBeliefCollectorAgent ready.")

View File

@@ -7,7 +7,7 @@ from spade.message import Message
from control_backend.core.config import settings from control_backend.core.config import settings
class BeliefFromText(CyclicBehaviour): class TextBeliefExtractorBehaviour(CyclicBehaviour):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# TODO: LLM prompt nog hardcoded # TODO: LLM prompt nog hardcoded
@@ -44,7 +44,7 @@ class BeliefFromText(CyclicBehaviour):
sender = msg.sender.node sender = msg.sender.node
match sender: match sender:
case settings.agent_settings.transcription_agent_name: case settings.agent_settings.transcription_name:
self.logger.debug("Received text from transcriber: %s", msg.body) self.logger.debug("Received text from transcriber: %s", msg.body)
await self._process_transcription_demo(msg.body) await self._process_transcription_demo(msg.body)
case _: case _:
@@ -71,7 +71,7 @@ class BeliefFromText(CyclicBehaviour):
belief_message = Message() belief_message = Message()
belief_message.to = ( belief_message.to = (
settings.agent_settings.belief_collector_agent_name settings.agent_settings.bdi_belief_collector_name
+ "@" + "@"
+ settings.agent_settings.host + settings.agent_settings.host
) )
@@ -95,7 +95,7 @@ class BeliefFromText(CyclicBehaviour):
belief_msg = Message() belief_msg = Message()
belief_msg.to = ( belief_msg.to = (
settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host settings.agent_settings.bdi_belief_collector_name + "@" + settings.agent_settings.host
) )
belief_msg.body = payload belief_msg.body = payload
belief_msg.thread = "beliefs" belief_msg.thread = "beliefs"

View File

@@ -0,0 +1,8 @@
from control_backend.agents.base import BaseAgent
from .behaviours.text_belief_extractor_behaviour import TextBeliefExtractorBehaviour
class TextBeliefExtractorAgent(BaseAgent):
async def setup(self):
self.add_behaviour(TextBeliefExtractorBehaviour())

View File

@@ -1,8 +0,0 @@
from control_backend.agents.base import BaseAgent
from .behaviours.text_belief_extractor import BeliefFromText
class TBeliefExtractorAgent(BaseAgent):
async def setup(self):
self.add_behaviour(BeliefFromText())

View File

@@ -1,11 +0,0 @@
from control_backend.agents.base import BaseAgent
from .behaviours.continuous_collect import ContinuousBeliefCollector
class BeliefCollectorAgent(BaseAgent):
async def setup(self):
self.logger.info("BeliefCollectorAgent starting (%s)", self.jid)
# Attach the continuous collector behaviour (listens and forwards to BDI)
self.add_behaviour(ContinuousBeliefCollector())
self.logger.info("BeliefCollectorAgent ready.")

View File

@@ -0,0 +1 @@
from .ri_communication_agent import RICommunicationAgent as RICommunicationAgent

View File

@@ -8,7 +8,7 @@ from zmq.asyncio import Context
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from .ri_command_agent import RICommandAgent from ..actuation.robot_speech_agent import RobotSpeechAgent
class RICommunicationAgent(BaseAgent): class RICommunicationAgent(BaseAgent):
@@ -87,8 +87,7 @@ class RICommunicationAgent(BaseAgent):
) )
except TimeoutError: except TimeoutError:
self.agent.logger.warning( self.agent.logger.warning(
"Initial connection ping for router timed" "Initial connection ping for router timed out in com_ri_agent."
" out in ri_communication_agent."
) )
# Try to reboot. # Try to reboot.
@@ -205,11 +204,11 @@ class RICommunicationAgent(BaseAgent):
else: else:
self._req_socket.bind(addr) self._req_socket.bind(addr)
case "actuation": case "actuation":
ri_commands_agent = RICommandAgent( ri_commands_agent = RobotSpeechAgent(
settings.agent_settings.ri_command_agent_name settings.agent_settings.robot_speech_name
+ "@" + "@"
+ settings.agent_settings.host, + settings.agent_settings.host,
settings.agent_settings.ri_command_agent_name, settings.agent_settings.robot_speech_name,
address=addr, address=addr,
bind=bind, bind=bind,
) )
@@ -243,9 +242,7 @@ class RICommunicationAgent(BaseAgent):
try: try:
await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5) await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5)
except TimeoutError: except TimeoutError:
self.logger.warning( self.logger.warning("Initial connection ping for router timed out in com_ri_agent.")
"Initial connection ping for router timed out in ri_communication_agent."
)
# Make sure to start listening now that we're connected. # Make sure to start listening now that we're connected.
self.connected = True self.connected = True

View File

@@ -0,0 +1 @@
from .llm_agent import LLMAgent as LLMAgent

View File

@@ -39,7 +39,7 @@ class LLMAgent(BaseAgent):
sender, sender,
) )
if sender == settings.agent_settings.bdi_core_agent_name: if sender == settings.agent_settings.bdi_core_name:
self.agent.logger.debug("Processing message from BDI Core Agent") self.agent.logger.debug("Processing message from BDI Core Agent")
await self._process_bdi_message(msg) await self._process_bdi_message(msg)
else: else:
@@ -63,7 +63,7 @@ class LLMAgent(BaseAgent):
Sends a response message back to the BDI Core Agent. Sends a response message back to the BDI Core Agent.
""" """
reply = Message( reply = Message(
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host, to=settings.agent_settings.bdi_core_name + "@" + settings.agent_settings.host,
body=msg, body=msg,
) )
await self.send(reply) await self.send(reply)

View File

@@ -1,44 +0,0 @@
import json
from spade.agent import Agent
from spade.behaviour import OneShotBehaviour
from spade.message import Message
from control_backend.core.config import settings
class BeliefTextAgent(Agent):
class SendOnceBehaviourBlfText(OneShotBehaviour):
async def run(self):
to_jid = (
settings.agent_settings.belief_collector_agent_name
+ "@"
+ settings.agent_settings.host
)
# Send multiple beliefs in one JSON payload
payload = {
"type": "belief_extraction_text",
"beliefs": {
"user_said": [
"hello test",
"Can you help me?",
"stop talking to me",
"No",
"Pepper do a dance",
]
},
}
msg = Message(to=to_jid)
msg.body = json.dumps(payload)
await self.send(msg)
print(f"Beliefs sent to {to_jid}!")
self.exit_code = "Job Finished!"
await self.agent.stop()
async def setup(self):
print("BeliefTextAgent started")
self.b = self.SendOnceBehaviourBlfText()
self.add_behaviour(self.b)

View File

@@ -0,0 +1,4 @@
from .transcription_agent.transcription_agent import (
TranscriptionAgent as TranscriptionAgent,
)
from .vad_agent import VADAgent as VADAgent

View File

@@ -19,13 +19,13 @@ class TranscriptionAgent(BaseAgent):
""" """
def __init__(self, audio_in_address: str): def __init__(self, audio_in_address: str):
jid = settings.agent_settings.transcription_agent_name + "@" + settings.agent_settings.host jid = settings.agent_settings.transcription_name + "@" + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.transcription_agent_name) super().__init__(jid, settings.agent_settings.transcription_name)
self.audio_in_address = audio_in_address self.audio_in_address = audio_in_address
self.audio_in_socket: azmq.Socket | None = None self.audio_in_socket: azmq.Socket | None = None
class Transcribing(CyclicBehaviour): class TranscribingBehaviour(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket): def __init__(self, audio_in_socket: azmq.Socket):
super().__init__() super().__init__()
self.audio_in_socket = audio_in_socket self.audio_in_socket = audio_in_socket
@@ -43,7 +43,7 @@ class TranscriptionAgent(BaseAgent):
async def _share_transcription(self, transcription: str): async def _share_transcription(self, transcription: str):
"""Share a transcription to the other agents that depend on it.""" """Share a transcription to the other agents that depend on it."""
receiver_jids = [ receiver_jids = [
settings.agent_settings.text_belief_extractor_agent_name settings.agent_settings.text_belief_extractor_name
+ "@" + "@"
+ settings.agent_settings.host, + settings.agent_settings.host,
] # Set message receivers here ] # Set message receivers here
@@ -79,7 +79,7 @@ class TranscriptionAgent(BaseAgent):
self._connect_audio_in_socket() self._connect_audio_in_socket()
transcribing = self.Transcribing(self.audio_in_socket) transcribing = self.TranscribingBehaviour(self.audio_in_socket)
transcribing.warmup() transcribing.warmup()
self.add_behaviour(transcribing) self.add_behaviour(transcribing)

View File

@@ -7,7 +7,7 @@ from spade.behaviour import CyclicBehaviour
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from .transcription.transcription_agent import TranscriptionAgent from .transcription_agent.transcription_agent import TranscriptionAgent
class SocketPoller[T]: class SocketPoller[T]:
@@ -40,7 +40,7 @@ class SocketPoller[T]:
return None return None
class Streaming(CyclicBehaviour): class StreamingBehaviour(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket): def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
super().__init__() super().__init__()
self.audio_in_poller = SocketPoller[bytes](audio_in_socket) self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
@@ -109,8 +109,8 @@ class VADAgent(BaseAgent):
""" """
def __init__(self, audio_in_address: str, audio_in_bind: bool): def __init__(self, audio_in_address: str, audio_in_bind: bool):
jid = settings.agent_settings.vad_agent_name + "@" + settings.agent_settings.host jid = settings.agent_settings.vad_name + "@" + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.vad_agent_name) super().__init__(jid, settings.agent_settings.vad_name)
self.audio_in_address = audio_in_address self.audio_in_address = audio_in_address
self.audio_in_bind = audio_in_bind self.audio_in_bind = audio_in_bind
@@ -118,7 +118,7 @@ class VADAgent(BaseAgent):
self.audio_in_socket: azmq.Socket | None = None self.audio_in_socket: azmq.Socket | None = None
self.audio_out_socket: azmq.Socket | None = None self.audio_out_socket: azmq.Socket | None = None
self.streaming_behaviour: Streaming | None = None self.streaming_behaviour: StreamingBehaviour | None = None
async def stop(self): async def stop(self):
""" """
@@ -162,7 +162,7 @@ class VADAgent(BaseAgent):
return return
audio_out_address = f"tcp://localhost:{audio_out_port}" audio_out_address = f"tcp://localhost:{audio_out_port}"
self.streaming_behaviour = Streaming(self.audio_in_socket, self.audio_out_socket) self.streaming_behaviour = StreamingBehaviour(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(self.streaming_behaviour) self.add_behaviour(self.streaming_behaviour)
# Start agents dependent on the output audio fragments here # Start agents dependent on the output audio fragments here

View File

@@ -9,16 +9,16 @@ class ZMQSettings(BaseModel):
class AgentSettings(BaseModel): class AgentSettings(BaseModel):
host: str = "localhost" host: str = "localhost"
bdi_core_agent_name: str = "bdi_core" bdi_core_name: str = "bdi_core_agent"
belief_collector_agent_name: str = "belief_collector" bdi_belief_collector_name: str = "belief_collector_agent"
text_belief_extractor_agent_name: str = "text_belief_extractor" text_belief_extractor_name: str = "text_belief_extractor_agent"
vad_agent_name: str = "vad_agent" vad_name: str = "vad_agent"
llm_agent_name: str = "llm_agent" llm_name: str = "llm_agent"
test_agent_name: str = "test_agent" test_name: str = "test_agent"
transcription_agent_name: str = "transcription_agent" transcription_name: str = "transcription_agent"
ri_communication_agent_name: str = "ri_communication_agent" ri_communication_name: str = "ri_communication_agent"
ri_command_agent_name: str = "ri_command_agent" robot_speech_name: str = "robot_speech_agent"
class LLMSettings(BaseModel): class LLMSettings(BaseModel):

View File

@@ -7,13 +7,25 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from zmq.asyncio import Context from zmq.asyncio import Context
from control_backend.agents import ( # Act agents
BeliefCollectorAgent, # BDI agents
LLMAgent, from control_backend.agents.bdi import (
RICommunicationAgent, BDIBeliefCollectorAgent,
VADAgent, BDICoreAgent,
TextBeliefExtractorAgent,
) )
from control_backend.agents.bdi import BDICoreAgent, TBeliefExtractorAgent
# Communication agents
from control_backend.agents.communication import RICommunicationAgent
# Emotional Agents
# LLM Agents
from control_backend.agents.llm import LLMAgent
# Perceive agents
from control_backend.agents.perception import VADAgent
# Other backend imports
from control_backend.api.v1.router import api_router from control_backend.api.v1.router import api_router
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.logging import setup_logging from control_backend.logging import setup_logging
@@ -67,10 +79,10 @@ async def lifespan(app: FastAPI):
"RICommunicationAgent": ( "RICommunicationAgent": (
RICommunicationAgent, RICommunicationAgent,
{ {
"name": settings.agent_settings.ri_communication_agent_name, "name": settings.agent_settings.ri_communication_name,
"jid": f"{settings.agent_settings.ri_communication_agent_name}" "jid": f"{settings.agent_settings.ri_communication_name}"
f"@{settings.agent_settings.host}", f"@{settings.agent_settings.host}",
"password": settings.agent_settings.ri_communication_agent_name, "password": settings.agent_settings.ri_communication_name,
"address": "tcp://*:5555", "address": "tcp://*:5555",
"bind": True, "bind": True,
}, },
@@ -78,37 +90,36 @@ async def lifespan(app: FastAPI):
"LLMAgent": ( "LLMAgent": (
LLMAgent, LLMAgent,
{ {
"name": settings.agent_settings.llm_agent_name, "name": settings.agent_settings.llm_name,
"jid": f"{settings.agent_settings.llm_agent_name}@{settings.agent_settings.host}", "jid": f"{settings.agent_settings.llm_name}@{settings.agent_settings.host}",
"password": settings.agent_settings.llm_agent_name, "password": settings.agent_settings.llm_name,
}, },
), ),
"BDICoreAgent": ( "BDICoreAgent": (
BDICoreAgent, BDICoreAgent,
{ {
"name": settings.agent_settings.bdi_core_agent_name, "name": settings.agent_settings.bdi_core_name,
"jid": f"{settings.agent_settings.bdi_core_agent_name}@" "jid": f"{settings.agent_settings.bdi_core_name}@{settings.agent_settings.host}",
f"{settings.agent_settings.host}", "password": settings.agent_settings.bdi_core_name,
"password": settings.agent_settings.bdi_core_agent_name, "asl": "src/control_backend/agents/bdi/bdi_core_agent/rules.asl",
"asl": "src/control_backend/agents/bdi/rules.asl",
}, },
), ),
"BeliefCollectorAgent": ( "BeliefCollectorAgent": (
BeliefCollectorAgent, BDIBeliefCollectorAgent,
{ {
"name": settings.agent_settings.belief_collector_agent_name, "name": settings.agent_settings.bdi_belief_collector_name,
"jid": f"{settings.agent_settings.belief_collector_agent_name}@" "jid": f"{settings.agent_settings.bdi_belief_collector_name}@"
f"{settings.agent_settings.host}", f"{settings.agent_settings.host}",
"password": settings.agent_settings.belief_collector_agent_name, "password": settings.agent_settings.bdi_belief_collector_name,
}, },
), ),
"TBeliefExtractor": ( "TextBeliefExtractorAgent": (
TBeliefExtractorAgent, TextBeliefExtractorAgent,
{ {
"name": settings.agent_settings.text_belief_extractor_agent_name, "name": settings.agent_settings.text_belief_extractor_name,
"jid": f"{settings.agent_settings.text_belief_extractor_agent_name}@" "jid": f"{settings.agent_settings.text_belief_extractor_name}@"
f"{settings.agent_settings.host}", f"{settings.agent_settings.host}",
"password": settings.agent_settings.text_belief_extractor_agent_name, "password": settings.agent_settings.text_belief_extractor_name,
}, },
), ),
"VADAgent": ( "VADAgent": (
@@ -117,17 +128,23 @@ async def lifespan(app: FastAPI):
), ),
} }
vad_agent = None
for name, (agent_class, kwargs) in agents_to_start.items(): for name, (agent_class, kwargs) in agents_to_start.items():
try: try:
logger.debug("Starting agent: %s", name) logger.debug("Starting agent: %s", name)
agent_instance = agent_class(**{k: v for k, v in kwargs.items() if k != "name"}) agent_instance = agent_class(**{k: v for k, v in kwargs.items() if k != "name"})
await agent_instance.start() await agent_instance.start()
if isinstance(agent_instance, VADAgent):
vad_agent = agent_instance
logger.info("Agent '%s' started successfully.", name) logger.info("Agent '%s' started successfully.", name)
except Exception as e: except Exception as e:
logger.error("Failed to start agent '%s': %s", name, e, exc_info=True) logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
# Consider if the application should continue if an agent fails to start. # Consider if the application should continue if an agent fails to start.
raise raise
await vad_agent.streaming_behaviour.reset()
logger.info("Application startup complete.") logger.info("Application startup complete.")
yield yield

View File

@@ -4,12 +4,14 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import zmq import zmq
from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.agents.actuation.robot_speech_agent import RobotSpeechAgent
@pytest.fixture @pytest.fixture
def zmq_context(mocker): def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") mock_context = mocker.patch(
"control_backend.agents.actuation.robot_speech_agent.zmq.Context.instance"
)
mock_context.return_value = MagicMock() mock_context.return_value = MagicMock()
return mock_context return mock_context
@@ -19,8 +21,8 @@ async def test_setup_bind(zmq_context, mocker):
"""Test setup with bind=True""" """Test setup with bind=True"""
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True) agent = RobotSpeechAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
settings = mocker.patch("control_backend.agents.ri_command_agent.settings") settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234" settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
await agent.setup() await agent.setup()
@@ -28,7 +30,9 @@ async def test_setup_bind(zmq_context, mocker):
fake_socket.bind.assert_any_call("tcp://localhost:5555") fake_socket.bind.assert_any_call("tcp://localhost:5555")
fake_socket.connect.assert_any_call("tcp://internal:1234") fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command") fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
assert any(isinstance(b, agent.SendCommandsBehaviour) for b in agent.behaviours)
# Ensure behaviour attached
assert any(isinstance(b, agent.SendZMQCommandsBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -36,8 +40,8 @@ async def test_setup_connect(zmq_context, mocker):
"""Test setup with bind=False""" """Test setup with bind=False"""
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False) agent = RobotSpeechAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
settings = mocker.patch("control_backend.agents.ri_command_agent.settings") settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234" settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
await agent.setup() await agent.setup()
@@ -55,14 +59,16 @@ async def test_send_commands_behaviour_valid_message():
) )
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RICommandAgent("test@server", "password") agent = RobotSpeechAgent("test@server", "password")
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
behaviour = agent.SendCommandsBehaviour() behaviour = agent.SendZMQCommandsBehaviour()
behaviour.agent = agent behaviour.agent = agent
with patch("control_backend.agents.ri_command_agent.SpeechCommand") as MockSpeechCommand: with patch(
"control_backend.agents.actuation.robot_speech_agent.SpeechCommand"
) as MockSpeechCommand:
mock_message = MagicMock() mock_message = MagicMock()
MockSpeechCommand.model_validate.return_value = mock_message MockSpeechCommand.model_validate.return_value = mock_message
@@ -79,11 +85,11 @@ async def test_send_commands_behaviour_invalid_message():
fake_socket.recv_multipart = AsyncMock(return_value=(b"command", b"{invalid_json}")) fake_socket.recv_multipart = AsyncMock(return_value=(b"command", b"{invalid_json}"))
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RICommandAgent("test@server", "password") agent = RobotSpeechAgent("test@server", "password")
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
behaviour = agent.SendCommandsBehaviour() behaviour = agent.SendZMQCommandsBehaviour()
behaviour.agent = agent behaviour.agent = agent
await behaviour.run() await behaviour.run()

View File

@@ -3,7 +3,11 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest import pytest
from control_backend.agents.ri_communication_agent import RICommunicationAgent from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
def speech_agent_path():
return "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent"
def fake_json_correct_negototiate_1(): def fake_json_correct_negototiate_1():
@@ -86,7 +90,9 @@ def fake_json_invalid_id_negototiate():
@pytest.fixture @pytest.fixture
def zmq_context(mocker): def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") mock_context = mocker.patch(
"control_backend.agents.communication.ri_communication_agent.zmq.Context.instance"
)
mock_context.return_value = MagicMock() mock_context.return_value = MagicMock()
return mock_context return mock_context
@@ -102,10 +108,8 @@ async def test_setup_creates_socket_and_negotiate_1(zmq_context):
fake_socket.recv_json = fake_json_correct_negototiate_1() fake_socket.recv_json = fake_json_correct_negototiate_1()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
@@ -144,10 +148,8 @@ async def test_setup_creates_socket_and_negotiate_2(zmq_context):
fake_socket.recv_json = fake_json_correct_negototiate_2() fake_socket.recv_json = fake_json_correct_negototiate_2()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
@@ -186,13 +188,11 @@ async def test_setup_creates_socket_and_negotiate_3(zmq_context):
fake_socket.recv_json = fake_json_wrong_negototiate_1() fake_socket.recv_json = fake_json_wrong_negototiate_1()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
# We are sending wrong negotiation info to the communication agent, # We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a better response, within a limited time. # so we should retry and expect a better response, within a limited time.
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
@@ -227,10 +227,8 @@ async def test_setup_creates_socket_and_negotiate_4(zmq_context):
fake_socket.recv_json = fake_json_correct_negototiate_3() fake_socket.recv_json = fake_json_correct_negototiate_3()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
@@ -268,10 +266,8 @@ async def test_setup_creates_socket_and_negotiate_5(zmq_context):
fake_socket.recv_json = fake_json_correct_negototiate_4() fake_socket.recv_json = fake_json_correct_negototiate_4()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
@@ -309,10 +305,8 @@ async def test_setup_creates_socket_and_negotiate_6(zmq_context):
fake_socket.recv_json = fake_json_correct_negototiate_5() fake_socket.recv_json = fake_json_correct_negototiate_5()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
@@ -350,13 +344,11 @@ async def test_setup_creates_socket_and_negotiate_7(zmq_context):
fake_socket.recv_json = fake_json_invalid_id_negototiate() fake_socket.recv_json = fake_json_invalid_id_negototiate()
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
# We are sending wrong negotiation info to the communication agent, # We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a better response, within a limited time. # so we should retry and expect a better response, within a limited time.
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
@@ -389,9 +381,7 @@ async def test_setup_creates_socket_and_negotiate_timeout(zmq_context):
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
fake_socket.send_multipart = AsyncMock() fake_socket.send_multipart = AsyncMock()
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
@@ -554,10 +544,8 @@ async def test_setup_unpacking_exception(zmq_context):
} # missing 'port' and 'bind' } # missing 'port' and 'bind'
fake_socket.recv_json = AsyncMock(return_value=malformed_data) fake_socket.recv_json = AsyncMock(return_value=malformed_data)
# Patch RICommandAgent so it won't actually start # Patch ActSpeechAgent so it won't actually start
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()

View File

@@ -5,43 +5,45 @@ import pytest
import zmq import zmq
from spade.agent import Agent from spade.agent import Agent
from control_backend.agents.vad_agent import VADAgent from control_backend.agents.perception.vad_agent import VADAgent
@pytest.fixture @pytest.fixture
def zmq_context(mocker): def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") mock_context = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock() mock_context.return_value = MagicMock()
return mock_context return mock_context
@pytest.fixture @pytest.fixture
def streaming(mocker): def streaming(mocker):
return mocker.patch("control_backend.agents.vad_agent.Streaming") return mocker.patch("control_backend.agents.perception.vad_agent.StreamingBehaviour")
@pytest.fixture @pytest.fixture
def transcription_agent(mocker): def per_transcription_agent(mocker):
return mocker.patch("control_backend.agents.vad_agent.TranscriptionAgent", autospec=True) return mocker.patch(
"control_backend.agents.perception.vad_agent.TranscriptionAgent", autospec=True
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_normal_setup(streaming, transcription_agent): async def test_normal_setup(streaming, per_transcription_agent):
""" """
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
sockets, and starts the TranscriptionAgent without loading real models. sockets, and starts the TranscriptionAgent without loading real models.
""" """
vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent.add_behaviour = MagicMock() per_vad_agent.add_behaviour = MagicMock()
await vad_agent.setup() await per_vad_agent.setup()
streaming.assert_called_once() streaming.assert_called_once()
vad_agent.add_behaviour.assert_called_once_with(streaming.return_value) per_vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
transcription_agent.assert_called_once() per_transcription_agent.assert_called_once()
transcription_agent.return_value.start.assert_called_once() per_transcription_agent.return_value.start.assert_called_once()
assert vad_agent.audio_in_socket is not None assert per_vad_agent.audio_in_socket is not None
assert vad_agent.audio_out_socket is not None assert per_vad_agent.audio_out_socket is not None
@pytest.mark.parametrize("do_bind", [True, False]) @pytest.mark.parametrize("do_bind", [True, False])
@@ -50,11 +52,11 @@ def test_in_socket_creation(zmq_context, do_bind: bool):
Test that the VAD agent creates an audio input socket, differentiating between binding and Test that the VAD agent creates an audio input socket, differentiating between binding and
connecting. connecting.
""" """
vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind) per_vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind)
vad_agent._connect_audio_in_socket() per_vad_agent._connect_audio_in_socket()
assert vad_agent.audio_in_socket is not None assert per_vad_agent.audio_in_socket is not None
zmq_context.return_value.socket.assert_called_once_with(zmq.SUB) zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with( zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(
@@ -74,11 +76,11 @@ def test_out_socket_creation(zmq_context):
""" """
Test that the VAD agent creates an audio output socket correctly. Test that the VAD agent creates an audio output socket correctly.
""" """
vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._connect_audio_out_socket() per_vad_agent._connect_audio_out_socket()
assert vad_agent.audio_out_socket is not None assert per_vad_agent.audio_out_socket is not None
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB) zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once() zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
@@ -93,28 +95,28 @@ async def test_out_socket_creation_failure(zmq_context):
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = ( zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
zmq.ZMQBindError zmq.ZMQBindError
) )
vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
await vad_agent.setup() await per_vad_agent.setup()
assert vad_agent.audio_out_socket is None assert per_vad_agent.audio_out_socket is None
mock_super_stop.assert_called_once() mock_super_stop.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stop(zmq_context, transcription_agent): async def test_stop(zmq_context, per_transcription_agent):
""" """
Test that when the VAD agent is stopped, the sockets are closed correctly. Test that when the VAD agent is stopped, the sockets are closed correctly.
""" """
vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint( zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
1000, 1000,
10000, 10000,
) )
await vad_agent.setup() await per_vad_agent.setup()
await vad_agent.stop() await per_vad_agent.stop()
assert zmq_context.return_value.socket.return_value.close.call_count == 2 assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert vad_agent.audio_in_socket is None assert per_vad_agent.audio_in_socket is None
assert vad_agent.audio_out_socket is None assert per_vad_agent.audio_out_socket is None

View File

@@ -5,7 +5,7 @@ import pytest
import soundfile as sf import soundfile as sf
import zmq import zmq
from control_backend.agents.vad_agent import Streaming from control_backend.agents.perception.vad_agent import StreamingBehaviour
def get_audio_chunks() -> list[bytes]: def get_audio_chunks() -> list[bytes]:
@@ -42,12 +42,12 @@ async def test_real_audio(mocker):
audio_in_socket = AsyncMock() audio_in_socket = AsyncMock()
audio_in_socket.recv.side_effect = audio_chunks audio_in_socket.recv.side_effect = audio_chunks
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)] mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)]
audio_out_socket = AsyncMock() audio_out_socket = AsyncMock()
vad_streamer = Streaming(audio_in_socket, audio_out_socket) vad_streamer = StreamingBehaviour(audio_in_socket, audio_out_socket)
vad_streamer._ready = True vad_streamer._ready = True
vad_streamer.agent = MagicMock() vad_streamer.agent = MagicMock()
for _ in audio_chunks: for _ in audio_chunks:

View File

@@ -4,10 +4,12 @@ from unittest.mock import AsyncMock, MagicMock, call
import pytest import pytest
from control_backend.agents.bdi.behaviours.belief_setter import BeliefSetterBehaviour from control_backend.agents.bdi.bdi_core_agent.behaviours.belief_setter_behaviour import (
BeliefSetterBehaviour,
)
# Define a constant for the collector agent name to use in tests # Define a constant for the collector agent name to use in tests
COLLECTOR_AGENT_NAME = "belief_collector" COLLECTOR_AGENT_NAME = "belief_collector_agent"
COLLECTOR_AGENT_JID = f"{COLLECTOR_AGENT_NAME}@test" COLLECTOR_AGENT_JID = f"{COLLECTOR_AGENT_NAME}@test"
@@ -21,11 +23,12 @@ def mock_agent(mocker):
@pytest.fixture @pytest.fixture
def belief_setter(mock_agent, mocker): def belief_setter_behaviour(mock_agent, mocker):
"""Fixture to create an instance of BeliefSetterBehaviour with a mocked agent.""" """Fixture to create an instance of BeliefSetterBehaviour with a mocked agent."""
# Patch the settings to use a predictable agent name # Patch the settings to use a predictable agent name
mocker.patch( mocker.patch(
"control_backend.agents.bdi.behaviours.belief_setter.settings.agent_settings.belief_collector_agent_name", "control_backend.agents.bdi.bdi_core_agent."
"behaviours.belief_setter_behaviour.settings.agent_settings.bdi_belief_collector_name",
COLLECTOR_AGENT_NAME, COLLECTOR_AGENT_NAME,
) )
@@ -46,53 +49,53 @@ def create_mock_message(sender_node: str, body: str, thread: str) -> MagicMock:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_message_received(belief_setter, mocker): async def test_run_message_received(belief_setter_behaviour, mocker):
""" """
Test that when a message is received, _process_message is called. Test that when a message is received, _process_message is called.
""" """
# Arrange # Arrange
msg = MagicMock() msg = MagicMock()
belief_setter.receive.return_value = msg belief_setter_behaviour.receive.return_value = msg
mocker.patch.object(belief_setter, "_process_message") mocker.patch.object(belief_setter_behaviour, "_process_message")
# Act # Act
await belief_setter.run() await belief_setter_behaviour.run()
# Assert # Assert
belief_setter._process_message.assert_called_once_with(msg) belief_setter_behaviour._process_message.assert_called_once_with(msg)
def test_process_message_from_belief_collector(belief_setter, mocker): def test_process_message_from_bdi_belief_collector_agent(belief_setter_behaviour, mocker):
""" """
Test processing a message from the correct belief collector agent. Test processing a message from the correct belief collector agent.
""" """
# Arrange # Arrange
msg = create_mock_message(sender_node=COLLECTOR_AGENT_NAME, body="", thread="") msg = create_mock_message(sender_node=COLLECTOR_AGENT_NAME, body="", thread="")
mock_process_belief = mocker.patch.object(belief_setter, "_process_belief_message") mock_process_belief = mocker.patch.object(belief_setter_behaviour, "_process_belief_message")
# Act # Act
belief_setter._process_message(msg) belief_setter_behaviour._process_message(msg)
# Assert # Assert
mock_process_belief.assert_called_once_with(msg) mock_process_belief.assert_called_once_with(msg)
def test_process_message_from_other_agent(belief_setter, mocker): def test_process_message_from_other_agent(belief_setter_behaviour, mocker):
""" """
Test that messages from other agents are ignored. Test that messages from other agents are ignored.
""" """
# Arrange # Arrange
msg = create_mock_message(sender_node="other_agent", body="", thread="") msg = create_mock_message(sender_node="other_agent", body="", thread="")
mock_process_belief = mocker.patch.object(belief_setter, "_process_belief_message") mock_process_belief = mocker.patch.object(belief_setter_behaviour, "_process_belief_message")
# Act # Act
belief_setter._process_message(msg) belief_setter_behaviour._process_message(msg)
# Assert # Assert
mock_process_belief.assert_not_called() mock_process_belief.assert_not_called()
def test_process_belief_message_valid_json(belief_setter, mocker): def test_process_belief_message_valid_json(belief_setter_behaviour, mocker):
""" """
Test processing a valid belief message with correct thread and JSON body. Test processing a valid belief message with correct thread and JSON body.
""" """
@@ -101,16 +104,16 @@ def test_process_belief_message_valid_json(belief_setter, mocker):
msg = create_mock_message( msg = create_mock_message(
sender_node=COLLECTOR_AGENT_JID, body=json.dumps(beliefs_payload), thread="beliefs" sender_node=COLLECTOR_AGENT_JID, body=json.dumps(beliefs_payload), thread="beliefs"
) )
mock_set_beliefs = mocker.patch.object(belief_setter, "_set_beliefs") mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
# Act # Act
belief_setter._process_belief_message(msg) belief_setter_behaviour._process_belief_message(msg)
# Assert # Assert
mock_set_beliefs.assert_called_once_with(beliefs_payload) mock_set_beliefs.assert_called_once_with(beliefs_payload)
def test_process_belief_message_invalid_json(belief_setter, mocker, caplog): def test_process_belief_message_invalid_json(belief_setter_behaviour, mocker, caplog):
""" """
Test that a message with invalid JSON is handled gracefully and an error is logged. Test that a message with invalid JSON is handled gracefully and an error is logged.
""" """
@@ -118,16 +121,16 @@ def test_process_belief_message_invalid_json(belief_setter, mocker, caplog):
msg = create_mock_message( msg = create_mock_message(
sender_node=COLLECTOR_AGENT_JID, body="this is not a json string", thread="beliefs" sender_node=COLLECTOR_AGENT_JID, body="this is not a json string", thread="beliefs"
) )
mock_set_beliefs = mocker.patch.object(belief_setter, "_set_beliefs") mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
# Act # Act
belief_setter._process_belief_message(msg) belief_setter_behaviour._process_belief_message(msg)
# Assert # Assert
mock_set_beliefs.assert_not_called() mock_set_beliefs.assert_not_called()
def test_process_belief_message_wrong_thread(belief_setter, mocker): def test_process_belief_message_wrong_thread(belief_setter_behaviour, mocker):
""" """
Test that a message with an incorrect thread is ignored. Test that a message with an incorrect thread is ignored.
""" """
@@ -135,31 +138,31 @@ def test_process_belief_message_wrong_thread(belief_setter, mocker):
msg = create_mock_message( msg = create_mock_message(
sender_node=COLLECTOR_AGENT_JID, body='{"some": "data"}', thread="not_beliefs" sender_node=COLLECTOR_AGENT_JID, body='{"some": "data"}', thread="not_beliefs"
) )
mock_set_beliefs = mocker.patch.object(belief_setter, "_set_beliefs") mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
# Act # Act
belief_setter._process_belief_message(msg) belief_setter_behaviour._process_belief_message(msg)
# Assert # Assert
mock_set_beliefs.assert_not_called() mock_set_beliefs.assert_not_called()
def test_process_belief_message_empty_body(belief_setter, mocker): def test_process_belief_message_empty_body(belief_setter_behaviour, mocker):
""" """
Test that a message with an empty body is ignored. Test that a message with an empty body is ignored.
""" """
# Arrange # Arrange
msg = create_mock_message(sender_node=COLLECTOR_AGENT_JID, body="", thread="beliefs") msg = create_mock_message(sender_node=COLLECTOR_AGENT_JID, body="", thread="beliefs")
mock_set_beliefs = mocker.patch.object(belief_setter, "_set_beliefs") mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
# Act # Act
belief_setter._process_belief_message(msg) belief_setter_behaviour._process_belief_message(msg)
# Assert # Assert
mock_set_beliefs.assert_not_called() mock_set_beliefs.assert_not_called()
def test_set_beliefs_success(belief_setter, mock_agent, caplog): def test_set_beliefs_success(belief_setter_behaviour, mock_agent, caplog):
""" """
Test that beliefs are correctly set on the agent's BDI. Test that beliefs are correctly set on the agent's BDI.
""" """
@@ -171,7 +174,7 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog):
# Act # Act
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
belief_setter._set_beliefs(beliefs_to_set) belief_setter_behaviour._set_beliefs(beliefs_to_set)
# Assert # Assert
expected_calls = [ expected_calls = [
@@ -182,18 +185,18 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog):
assert mock_agent.bdi.set_belief.call_count == 2 assert mock_agent.bdi.set_belief.call_count == 2
# def test_responded_unset(belief_setter, mock_agent): # def test_responded_unset(belief_setter_behaviour, mock_agent):
# # Arrange # # Arrange
# new_beliefs = {"user_said": ["message"]} # new_beliefs = {"user_said": ["message"]}
# #
# # Act # # Act
# belief_setter._set_beliefs(new_beliefs) # belief_setter_behaviour._set_beliefs(new_beliefs)
# #
# # Assert # # Assert
# mock_agent.bdi.set_belief.assert_has_calls([call("user_said", "message")]) # mock_agent.bdi.set_belief.assert_has_calls([call("user_said", "message")])
# mock_agent.bdi.remove_belief.assert_has_calls([call("responded")]) # mock_agent.bdi.remove_belief.assert_has_calls([call("responded")])
# def test_set_beliefs_bdi_not_initialized(belief_setter, mock_agent, caplog): # def test_set_beliefs_bdi_not_initialized(belief_setter_behaviour, mock_agent, caplog):
# """ # """
# Test that a warning is logged if the agent's BDI is not initialized. # Test that a warning is logged if the agent's BDI is not initialized.
# """ # """
@@ -203,7 +206,7 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog):
# #
# # Act # # Act
# with caplog.at_level(logging.WARNING): # with caplog.at_level(logging.WARNING):
# belief_setter._set_beliefs(beliefs_to_set) # belief_setter_behaviour._set_beliefs(beliefs_to_set)
# #
# # Assert # # Assert
# assert "Cannot set beliefs, since agent's BDI is not yet initialized." in caplog.text # assert "Cannot set beliefs, since agent's BDI is not yet initialized." in caplog.text

View File

@@ -0,0 +1,101 @@
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from control_backend.agents.bdi.belief_collector_agent.behaviours.belief_collector_behaviour import ( # noqa: E501
BeliefCollectorBehaviour,
)
def create_mock_message(sender_node: str, body: str) -> MagicMock:
"""Helper function to create a configured mock message."""
msg = MagicMock()
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
msg.body = body
return msg
@pytest.fixture
def mock_agent(mocker):
"""Fixture to create a mock Agent."""
agent = MagicMock()
agent.jid = "belief_collector_agent@test"
return agent
@pytest.fixture
def bel_collector_behaviouror(mock_agent, mocker):
"""Fixture to create an instance of BelCollectorBehaviour with a mocked agent."""
# Patch asyncio.sleep to prevent tests from actually waiting
mocker.patch("asyncio.sleep", return_value=None)
collector = BeliefCollectorBehaviour()
collector.agent = mock_agent
# Mock the receive method, we will control its return value in each test
collector.receive = AsyncMock()
return collector
@pytest.mark.asyncio
async def test_run_message_received(bel_collector_behaviouror, mocker):
"""
Test that when a message is received, _process_message is called with that message.
"""
# Arrange
mock_msg = MagicMock()
bel_collector_behaviouror.receive.return_value = mock_msg
mocker.patch.object(bel_collector_behaviouror, "_process_message")
# Act
await bel_collector_behaviouror.run()
# Assert
bel_collector_behaviouror._process_message.assert_awaited_once_with(mock_msg)
@pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_type(bel_collector_behaviouror, mocker):
msg = create_mock_message(
"anyone",
json.dumps({"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}}),
)
spy = mocker.patch.object(bel_collector_behaviouror, "_handle_belief_text", new=AsyncMock())
await bel_collector_behaviouror._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_sender(bel_collector_behaviouror, mocker):
msg = create_mock_message(
"bel_text_agent_mock", json.dumps({"beliefs": {"user_said": [["hi"]]}})
)
spy = mocker.patch.object(bel_collector_behaviouror, "_handle_belief_text", new=AsyncMock())
await bel_collector_behaviouror._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_routes_to_handle_emo_text(bel_collector_behaviouror, mocker):
msg = create_mock_message("anyone", json.dumps({"type": "emotion_extraction_text"}))
spy = mocker.patch.object(bel_collector_behaviouror, "_handle_emo_text", new=AsyncMock())
await bel_collector_behaviouror._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_belief_text_happy_path_sends(bel_collector_behaviouror, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
bel_collector_behaviouror.send = AsyncMock()
await bel_collector_behaviouror._handle_belief_text(payload, "bel_text_agent_mock")
# make sure we attempted a send
bel_collector_behaviouror.send.assert_awaited_once()
@pytest.mark.asyncio
async def test_belief_text_coerces_non_strings(bel_collector_behaviouror, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}}
bel_collector_behaviouror.send = AsyncMock()
await bel_collector_behaviouror._handle_belief_text(payload, "origin")
bel_collector_behaviouror.send.assert_awaited_once()

View File

@@ -4,7 +4,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from spade.message import Message from spade.message import Message
from control_backend.agents.bdi.behaviours.text_belief_extractor import BeliefFromText from control_backend.agents.bdi.text_belief_extractor_agent.behaviours.text_belief_extractor_behaviour import ( # noqa: E501, We can't shorten this import.
TextBeliefExtractorBehaviour,
)
@pytest.fixture @pytest.fixture
@@ -15,15 +17,17 @@ def mock_settings():
""" """
# Create a mock object that mimics the nested structure # Create a mock object that mimics the nested structure
settings_mock = MagicMock() settings_mock = MagicMock()
settings_mock.agent_settings.transcription_agent_name = "transcriber" settings_mock.agent_settings.transcription_name = "transcriber"
settings_mock.agent_settings.belief_collector_agent_name = "collector" settings_mock.agent_settings.bdi_belief_collector_name = "collector"
settings_mock.agent_settings.host = "fake.host" settings_mock.agent_settings.host = "fake.host"
# Use patch to replace the settings object during the test # Use patch to replace the settings object during the test
# Adjust 'control_backend.behaviours.belief_from_text.settings' to where # Adjust 'control_backend.behaviours.belief_from_text.settings' to where
# your behaviour file imports it from. # your behaviour file imports it from.
with patch( with patch(
"control_backend.agents.bdi.behaviours.text_belief_extractor.settings", settings_mock "control_backend.agents.bdi.text_belief_extractor_agent.behaviours"
".text_belief_extractor_behaviour.settings",
settings_mock,
): ):
yield settings_mock yield settings_mock
@@ -31,10 +35,10 @@ def mock_settings():
@pytest.fixture @pytest.fixture
def behavior(mock_settings): def behavior(mock_settings):
""" """
Creates an instance of the BeliefFromText behaviour and mocks its Creates an instance of the BDITextBeliefBehaviour behaviour and mocks its
agent, logger, send, and receive methods. agent, logger, send, and receive methods.
""" """
b = BeliefFromText() b = TextBeliefExtractorBehaviour()
b.agent = MagicMock() b.agent = MagicMock()
b.send = AsyncMock() b.send = AsyncMock()
@@ -100,7 +104,7 @@ async def test_run_message_from_transcriber_demo(behavior, mock_settings, monkey
# Arrange: Create a mock message from the transcriber # Arrange: Create a mock message from the transcriber
transcription_text = "hello world" transcription_text = "hello world"
mock_msg = create_mock_message( mock_msg = create_mock_message(
mock_settings.agent_settings.transcription_agent_name, transcription_text, None mock_settings.agent_settings.transcription_name, transcription_text, None
) )
behavior.receive.return_value = mock_msg behavior.receive.return_value = mock_msg
@@ -119,7 +123,7 @@ async def test_run_message_from_transcriber_demo(behavior, mock_settings, monkey
assert ( assert (
sent_msg.to sent_msg.to
== mock_settings.agent_settings.belief_collector_agent_name == mock_settings.agent_settings.bdi_belief_collector_name
+ "@" + "@"
+ mock_settings.agent_settings.host + mock_settings.agent_settings.host
) )
@@ -159,7 +163,7 @@ async def test_process_transcription_success(behavior, mock_settings):
# 2. Inspect the sent message # 2. Inspect the sent message
sent_msg: Message = behavior.send.call_args[0][0] sent_msg: Message = behavior.send.call_args[0][0]
expected_to = ( expected_to = (
mock_settings.agent_settings.belief_collector_agent_name mock_settings.agent_settings.bdi_belief_collector_name
+ "@" + "@"
+ mock_settings.agent_settings.host + mock_settings.agent_settings.host
) )

View File

@@ -1,101 +0,0 @@
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from control_backend.agents.belief_collector.behaviours.continuous_collect import (
ContinuousBeliefCollector,
)
def create_mock_message(sender_node: str, body: str) -> MagicMock:
"""Helper function to create a configured mock message."""
msg = MagicMock()
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
msg.body = body
return msg
@pytest.fixture
def mock_agent(mocker):
"""Fixture to create a mock Agent."""
agent = MagicMock()
agent.jid = "belief_collector_agent@test"
return agent
@pytest.fixture
def continuous_collector(mock_agent, mocker):
"""Fixture to create an instance of ContinuousBeliefCollector with a mocked agent."""
# Patch asyncio.sleep to prevent tests from actually waiting
mocker.patch("asyncio.sleep", return_value=None)
collector = ContinuousBeliefCollector()
collector.agent = mock_agent
# Mock the receive method, we will control its return value in each test
collector.receive = AsyncMock()
return collector
@pytest.mark.asyncio
async def test_run_message_received(continuous_collector, mocker):
"""
Test that when a message is received, _process_message is called with that message.
"""
# Arrange
mock_msg = MagicMock()
continuous_collector.receive.return_value = mock_msg
mocker.patch.object(continuous_collector, "_process_message")
# Act
await continuous_collector.run()
# Assert
continuous_collector._process_message.assert_awaited_once_with(mock_msg)
@pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker):
msg = create_mock_message(
"anyone",
json.dumps({"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}}),
)
spy = mocker.patch.object(continuous_collector, "_handle_belief_text", new=AsyncMock())
await continuous_collector._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker):
msg = create_mock_message(
"belief_text_agent_mock", json.dumps({"beliefs": {"user_said": [["hi"]]}})
)
spy = mocker.patch.object(continuous_collector, "_handle_belief_text", new=AsyncMock())
await continuous_collector._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_routes_to_handle_emo_text(continuous_collector, mocker):
msg = create_mock_message("anyone", json.dumps({"type": "emotion_extraction_text"}))
spy = mocker.patch.object(continuous_collector, "_handle_emo_text", new=AsyncMock())
await continuous_collector._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_belief_text_happy_path_sends(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
continuous_collector.send = AsyncMock()
await continuous_collector._handle_belief_text(payload, "belief_text_agent_mock")
# make sure we attempted a send
continuous_collector.send.assert_awaited_once()
@pytest.mark.asyncio
async def test_belief_text_coerces_non_strings(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}}
continuous_collector.send = AsyncMock()
await continuous_collector._handle_belief_text(payload, "origin")
continuous_collector.send.assert_awaited_once()

View File

@@ -1,6 +1,6 @@
import numpy as np import numpy as np
from control_backend.agents.transcription.speech_recognizer import ( from control_backend.agents.perception.transcription_agent.speech_recognizer import (
OpenAIWhisperSpeechRecognizer, OpenAIWhisperSpeechRecognizer,
SpeechRecognizer, SpeechRecognizer,
) )

View File

@@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
import zmq import zmq
from control_backend.agents.vad_agent import SocketPoller from control_backend.agents.perception.vad_agent import SocketPoller
@pytest.fixture @pytest.fixture
@@ -16,7 +16,7 @@ async def test_socket_poller_with_data(socket, mocker):
socket_data = b"test" socket_data = b"test"
socket.recv.return_value = socket_data socket.recv.return_value = socket_data
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)] mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
poller = SocketPoller(socket) poller = SocketPoller(socket)
@@ -35,7 +35,7 @@ async def test_socket_poller_with_data(socket, mocker):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_socket_poller_no_data(socket, mocker): async def test_socket_poller_no_data(socket, mocker):
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [] mock_poller.return_value.poll.return_value = []
poller = SocketPoller(socket) poller = SocketPoller(socket)

View File

@@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
import numpy as np import numpy as np
import pytest import pytest
from control_backend.agents.vad_agent import Streaming from control_backend.agents.perception.vad_agent import StreamingBehaviour
@pytest.fixture @pytest.fixture
@@ -29,7 +29,7 @@ def streaming(audio_in_socket, audio_out_socket, mock_agent):
import torch import torch
torch.hub.load.return_value = (..., ...) # Mock torch.hub.load.return_value = (..., ...) # Mock
streaming = Streaming(audio_in_socket, audio_out_socket) streaming = StreamingBehaviour(audio_in_socket, audio_out_socket)
streaming._ready = True streaming._ready = True
streaming.agent = mock_agent streaming.agent = mock_agent
return streaming return streaming