Merge remote-tracking branch 'origin/dev' into refactor/config-file

# Conflicts:
#	src/control_backend/agents/ri_communication_agent.py
#	src/control_backend/core/config.py
#	src/control_backend/main.py
This commit is contained in:
Twirre Meulenbelt
2025-11-19 17:30:48 +01:00
46 changed files with 1207 additions and 651 deletions

View File

@@ -1,7 +1 @@
from .base import BaseAgent as BaseAgent from .base import BaseAgent as BaseAgent
from .belief_collector.belief_collector import BeliefCollectorAgent as BeliefCollectorAgent
from .llm.llm import LLMAgent as LLMAgent
from .ri_command_agent import RICommandAgent as RICommandAgent
from .ri_communication_agent import RICommunicationAgent as RICommunicationAgent
from .transcription.transcription_agent import TranscriptionAgent as TranscriptionAgent
from .vad_agent import VADAgent as VADAgent

View File

@@ -0,0 +1 @@
from .robot_speech_agent import RobotSpeechAgent as RobotSpeechAgent

View File

@@ -10,7 +10,7 @@ from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand
class RICommandAgent(BaseAgent): class RobotSpeechAgent(BaseAgent):
subsocket: zmq.Socket subsocket: zmq.Socket
pubsocket: zmq.Socket pubsocket: zmq.Socket
address = "" address = ""
@@ -29,7 +29,7 @@ class RICommandAgent(BaseAgent):
self.address = address self.address = address
self.bind = bind self.bind = bind
class SendCommandsBehaviour(CyclicBehaviour): class SendZMQCommandsBehaviour(CyclicBehaviour):
"""Behaviour for sending commands received from the UI.""" """Behaviour for sending commands received from the UI."""
async def run(self): async def run(self):
@@ -50,7 +50,7 @@ class RICommandAgent(BaseAgent):
except Exception as e: except Exception as e:
self.agent.logger.error("Error processing message: %s", e) self.agent.logger.error("Error processing message: %s", e)
class SendPythonCommandsBehaviour(CyclicBehaviour): class SendSpadeCommandsBehaviour(CyclicBehaviour):
"""Behaviour for sending commands received from other Python agents.""" """Behaviour for sending commands received from other Python agents."""
async def run(self): async def run(self):
@@ -64,7 +64,7 @@ class RICommandAgent(BaseAgent):
async def setup(self): async def setup(self):
""" """
Setup the command agent Setup the robot speech command agent
""" """
self.logger.info("Setting up %s", self.jid) self.logger.info("Setting up %s", self.jid)
@@ -72,7 +72,7 @@ class RICommandAgent(BaseAgent):
# To the robot # To the robot
self.pubsocket = context.socket(zmq.PUB) self.pubsocket = context.socket(zmq.PUB)
if self.bind: if self.bind: # TODO: Should this ever be the case?
self.pubsocket.bind(self.address) self.pubsocket.bind(self.address)
else: else:
self.pubsocket.connect(self.address) self.pubsocket.connect(self.address)
@@ -83,8 +83,8 @@ class RICommandAgent(BaseAgent):
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command") self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
# Add behaviour to our agent # Add behaviour to our agent
commands_behaviour = self.SendCommandsBehaviour() commands_behaviour = self.SendZMQCommandsBehaviour()
self.add_behaviour(commands_behaviour) self.add_behaviour(commands_behaviour)
self.add_behaviour(self.SendPythonCommandsBehaviour()) self.add_behaviour(self.SendSpadeCommandsBehaviour())
self.logger.info("Finished setting up %s", self.jid) self.logger.info("Finished setting up %s", self.jid)

View File

@@ -1,2 +1,7 @@
from .bdi_core import BDICoreAgent as BDICoreAgent from .bdi_core_agent.bdi_core_agent import BDICoreAgent as BDICoreAgent
from .text_extractor import TBeliefExtractorAgent as TBeliefExtractorAgent from .belief_collector_agent.belief_collector_agent import (
BDIBeliefCollectorAgent as BDIBeliefCollectorAgent,
)
from .text_belief_extractor_agent.text_belief_extractor_agent import (
TextBeliefExtractorAgent as TextBeliefExtractorAgent,
)

View File

@@ -7,7 +7,7 @@ from spade_bdi.bdi import BDIAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from .behaviours.belief_setter import BeliefSetterBehaviour from .behaviours.belief_setter_behaviour import BeliefSetterBehaviour
from .behaviours.receive_llm_resp_behaviour import ReceiveLLMResponseBehaviour from .behaviours.receive_llm_resp_behaviour import ReceiveLLMResponseBehaviour
@@ -57,7 +57,7 @@ class BDICoreAgent(BDIAgent):
class SendBehaviour(OneShotBehaviour): class SendBehaviour(OneShotBehaviour):
async def run(self) -> None: async def run(self) -> None:
msg = Message( msg = Message(
to=settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host, to=settings.agent_settings.llm_name + "@" + settings.agent_settings.host,
body=text, body=text,
) )

View File

@@ -32,7 +32,7 @@ class BeliefSetterBehaviour(CyclicBehaviour):
self.agent.logger.debug("Processing message from sender: %s", sender) self.agent.logger.debug("Processing message from sender: %s", sender)
match sender: match sender:
case settings.agent_settings.belief_collector_agent_name: case settings.agent_settings.bdi_belief_collector_name:
self.agent.logger.debug( self.agent.logger.debug(
"Message is from the belief collector agent. Processing as belief message." "Message is from the belief collector agent. Processing as belief message."
) )

View File

@@ -15,14 +15,14 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
sender = msg.sender.node sender = msg.sender.node
match sender: match sender:
case settings.agent_settings.llm_agent_name: case settings.agent_settings.llm_name:
content = msg.body content = msg.body
self.agent.logger.info("Received LLM response: %s", content) self.agent.logger.info("Received LLM response: %s", content)
speech_command = SpeechCommand(data=content) speech_command = SpeechCommand(data=content)
message = Message( message = Message(
to=settings.agent_settings.ri_command_agent_name to=settings.agent_settings.robot_speech_name
+ "@" + "@"
+ settings.agent_settings.host, + settings.agent_settings.host,
sender=self.agent.jid, sender=self.agent.jid,

View File

@@ -7,7 +7,7 @@ from spade.behaviour import CyclicBehaviour
from control_backend.core.config import settings from control_backend.core.config import settings
class ContinuousBeliefCollector(CyclicBehaviour): class BeliefCollectorBehaviour(CyclicBehaviour):
""" """
Continuously collects beliefs/emotions from extractor agents: Continuously collects beliefs/emotions from extractor agents:
Then we send a unified belief packet to the BDI agent. Then we send a unified belief packet to the BDI agent.
@@ -35,7 +35,7 @@ class ContinuousBeliefCollector(CyclicBehaviour):
msg_type = payload.get("type") msg_type = payload.get("type")
# Prefer explicit 'type' field # Prefer explicit 'type' field
if msg_type == "belief_extraction_text" or sender_node == "belief_text_agent_mock": if msg_type == "belief_extraction_text" or sender_node == "bel_text_agent_mock":
self.agent.logger.debug( self.agent.logger.debug(
"Message routed to _handle_belief_text (sender=%s)", sender_node "Message routed to _handle_belief_text (sender=%s)", sender_node
) )
@@ -83,7 +83,7 @@ class ContinuousBeliefCollector(CyclicBehaviour):
if not beliefs: if not beliefs:
return return
to_jid = f"{settings.agent_settings.bdi_core_agent_name}@{settings.agent_settings.host}" to_jid = f"{settings.agent_settings.bdi_core_name}@{settings.agent_settings.host}"
msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs") msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs")
msg.body = json.dumps(beliefs) msg.body = json.dumps(beliefs)

View File

@@ -0,0 +1,11 @@
from control_backend.agents.base import BaseAgent
from .behaviours.belief_collector_behaviour import BeliefCollectorBehaviour
class BDIBeliefCollectorAgent(BaseAgent):
async def setup(self):
self.logger.info("BDIBeliefCollectorAgent starting (%s)", self.jid)
# Attach the continuous collector behaviour (listens and forwards to BDI)
self.add_behaviour(BeliefCollectorBehaviour())
self.logger.info("BDIBeliefCollectorAgent ready.")

View File

@@ -7,7 +7,7 @@ from spade.message import Message
from control_backend.core.config import settings from control_backend.core.config import settings
class BeliefFromText(CyclicBehaviour): class TextBeliefExtractorBehaviour(CyclicBehaviour):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# TODO: LLM prompt nog hardcoded # TODO: LLM prompt nog hardcoded
@@ -44,7 +44,7 @@ class BeliefFromText(CyclicBehaviour):
sender = msg.sender.node sender = msg.sender.node
match sender: match sender:
case settings.agent_settings.transcription_agent_name: case settings.agent_settings.transcription_name:
self.logger.debug("Received text from transcriber: %s", msg.body) self.logger.debug("Received text from transcriber: %s", msg.body)
await self._process_transcription_demo(msg.body) await self._process_transcription_demo(msg.body)
case _: case _:
@@ -71,7 +71,7 @@ class BeliefFromText(CyclicBehaviour):
belief_message = Message() belief_message = Message()
belief_message.to = ( belief_message.to = (
settings.agent_settings.belief_collector_agent_name settings.agent_settings.bdi_belief_collector_name
+ "@" + "@"
+ settings.agent_settings.host + settings.agent_settings.host
) )
@@ -95,7 +95,7 @@ class BeliefFromText(CyclicBehaviour):
belief_msg = Message() belief_msg = Message()
belief_msg.to = ( belief_msg.to = (
settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host settings.agent_settings.bdi_belief_collector_name + "@" + settings.agent_settings.host
) )
belief_msg.body = payload belief_msg.body = payload
belief_msg.thread = "beliefs" belief_msg.thread = "beliefs"

View File

@@ -0,0 +1,8 @@
from control_backend.agents.base import BaseAgent
from .behaviours.text_belief_extractor_behaviour import TextBeliefExtractorBehaviour
class TextBeliefExtractorAgent(BaseAgent):
async def setup(self):
self.add_behaviour(TextBeliefExtractorBehaviour())

View File

@@ -1,8 +0,0 @@
from control_backend.agents.base import BaseAgent
from .behaviours.text_belief_extractor import BeliefFromText
class TBeliefExtractorAgent(BaseAgent):
async def setup(self):
self.add_behaviour(BeliefFromText())

View File

@@ -1,11 +0,0 @@
from control_backend.agents.base import BaseAgent
from .behaviours.continuous_collect import ContinuousBeliefCollector
class BeliefCollectorAgent(BaseAgent):
async def setup(self):
self.logger.info("BeliefCollectorAgent starting (%s)", self.jid)
# Attach the continuous collector behaviour (listens and forwards to BDI)
self.add_behaviour(ContinuousBeliefCollector())
self.logger.info("BeliefCollectorAgent ready.")

View File

@@ -0,0 +1 @@
from .ri_communication_agent import RICommunicationAgent as RICommunicationAgent

View File

@@ -0,0 +1,250 @@
import asyncio
import json
import zmq.asyncio
from spade.behaviour import CyclicBehaviour
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from ..actuation.robot_speech_agent import RobotSpeechAgent
class RICommunicationAgent(BaseAgent):
req_socket: zmq.Socket
_address = ""
_bind = True
connected = False
def __init__(
self,
jid: str,
password: str,
port: int = settings.agent_settings.default_spade_port,
verify_security: bool = False,
address=settings.zmq_settings.ri_command_address,
bind=False,
):
super().__init__(jid, password, port, verify_security)
self._address = address
self._bind = bind
self._req_socket: zmq.asyncio.Socket | None = None
self.pub_socket: zmq.asyncio.Socket | None = None
class ListenBehaviour(CyclicBehaviour):
async def run(self):
"""
Run the listening (ping) loop indefinetely.
"""
assert self.agent is not None
if not self.agent.connected:
await asyncio.sleep(settings.behaviour_settings.sleep_s)
return
# We need to listen and sent pings.
message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}}
seconds_to_wait_total = settings.behaviour_settings.sleep_s
try:
await asyncio.wait_for(
self.agent._req_socket.send_json(message), timeout=seconds_to_wait_total / 2
)
except TimeoutError:
self.agent.logger.debug(
"Waited too long to send message - "
"we probably dont have any receivers... but let's check!"
)
# Wait up to {seconds_to_wait_total/2} seconds for a reply
try:
message = await asyncio.wait_for(
self.agent._req_socket.recv_json(), timeout=seconds_to_wait_total / 2
)
# We didnt get a reply
except TimeoutError:
self.agent.logger.info(
f"No ping retrieved in {seconds_to_wait_total} seconds, "
"sending UI disconnection event and attempting to restart."
)
# Make sure we dont retry receiving messages untill we're setup.
self.agent.connected = False
self.agent.remove_behaviour(self)
# Tell UI we're disconnected.
topic = b"ping"
data = json.dumps(False).encode()
if self.agent.pub_socket is None:
self.agent.logger.warning(
"Communication agent pub socket not correctly initialized."
)
else:
try:
await asyncio.wait_for(
self.agent.pub_socket.send_multipart([topic, data]), 5
)
except TimeoutError:
self.agent.logger.warning(
f"Initial connection ping for router timed out in {self.agent.name}."
)
# Try to reboot.
self.agent.logger.debug("Restarting communication agent.")
await self.agent.setup()
self.agent.logger.debug(f'Received message "{message}" from RI.')
if "endpoint" not in message:
self.agent.logger.warning(
"No received endpoint in message, expected ping endpoint."
)
return
# See what endpoint we received
match message["endpoint"]:
case "ping":
topic = b"ping"
data = json.dumps(True).encode()
if self.agent.pub_socket is not None:
await self.agent.pub_socket.send_multipart([topic, data])
await asyncio.sleep(settings.behaviour_settings.sleep_s)
case _:
self.agent.logger.debug(
"Received message with topic different than ping, while ping expected."
)
async def setup_sockets(self, force=False):
"""
Sets up request socket for communication agent.
"""
# Bind request socket
if self._req_socket is None or force:
self._req_socket = Context.instance().socket(zmq.REQ)
if self._bind:
self._req_socket.bind(self._address)
else:
self._req_socket.connect(self._address)
if self.pub_socket is None or force:
self.pub_socket = Context.instance().socket(zmq.PUB)
self.pub_socket.connect(settings.zmq_settings.internal_pub_address)
async def setup(self, max_retries: int = settings.behaviour_settings.comm_setup_max_retries):
"""
Try to set up the communication agent, we have `behaviour_settings.comm_setup_max_retries`
retries in case we don't have a response yet.
"""
self.logger.info("Setting up %s", self.jid)
# Bind request socket
await self.setup_sockets()
retries = 0
# Let's try a certain amount of times before failing connection
while retries < max_retries:
# Make sure the socket is properly setup.
if self._req_socket is None:
continue
# Send our message and receive one back
message = {"endpoint": "negotiate/ports", "data": {}}
await self._req_socket.send_json(message)
retry_frequency = 1.0
try:
received_message = await asyncio.wait_for(
self._req_socket.recv_json(), timeout=retry_frequency
)
except TimeoutError:
self.logger.warning(
"No connection established in %d seconds (attempt %d/%d)",
retries * retry_frequency,
retries + 1,
max_retries,
)
retries += 1
continue
except Exception as e:
self.logger.warning("Unexpected error during negotiation: %s", e)
retries += 1
continue
# Validate endpoint
endpoint = received_message.get("endpoint")
if endpoint != "negotiate/ports":
self.logger.warning(
"Invalid endpoint '%s' received (attempt %d/%d)",
endpoint,
retries + 1,
max_retries,
)
retries += 1
await asyncio.sleep(1)
continue
# At this point, we have a valid response
try:
for port_data in received_message["data"]:
id = port_data["id"]
port = port_data["port"]
bind = port_data["bind"]
if not bind:
addr = f"tcp://localhost:{port}"
else:
addr = f"tcp://*:{port}"
match id:
case "main":
if addr != self._address:
if not bind:
self._req_socket.connect(addr)
else:
self._req_socket.bind(addr)
case "actuation":
ri_commands_agent = RobotSpeechAgent(
settings.agent_settings.robot_speech_name
+ "@"
+ settings.agent_settings.host,
settings.agent_settings.robot_speech_name,
address=addr,
bind=bind,
)
await ri_commands_agent.start()
case _:
self.logger.warning("Unhandled negotiation id: %s", id)
except Exception as e:
self.logger.warning("Error unpacking negotiation data: %s", e)
retries += 1
await asyncio.sleep(1)
continue
# setup succeeded
break
else:
self.logger.warning("Failed to set up %s after %d retries", self.name, max_retries)
return
# Set up ping behaviour
listen_behaviour = self.ListenBehaviour()
self.add_behaviour(listen_behaviour)
# Let UI know that we're connected
topic = b"ping"
data = json.dumps(True).encode()
if self.pub_socket is None:
self.logger.warning("Communication agent pub socket not correctly initialized.")
else:
try:
await asyncio.wait_for(self.pub_socket.send_multipart([topic, data]), 5)
except TimeoutError:
self.logger.warning("Initial connection ping for router timed out in com_ri_agent.")
# Make sure to start listening now that we're connected.
self.connected = True
self.logger.info("Finished setting up %s", self.jid)

View File

@@ -0,0 +1 @@
from .llm_agent import LLMAgent as LLMAgent

View File

@@ -39,7 +39,7 @@ class LLMAgent(BaseAgent):
sender, sender,
) )
if sender == settings.agent_settings.bdi_core_agent_name: if sender == settings.agent_settings.bdi_core_name:
self.agent.logger.debug("Processing message from BDI Core Agent") self.agent.logger.debug("Processing message from BDI Core Agent")
await self._process_bdi_message(msg) await self._process_bdi_message(msg)
else: else:
@@ -63,7 +63,7 @@ class LLMAgent(BaseAgent):
Sends a response message back to the BDI Core Agent. Sends a response message back to the BDI Core Agent.
""" """
reply = Message( reply = Message(
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host, to=settings.agent_settings.bdi_core_name + "@" + settings.agent_settings.host,
body=msg, body=msg,
) )
await self.send(reply) await self.send(reply)

View File

@@ -0,0 +1,4 @@
from .transcription_agent.transcription_agent import (
TranscriptionAgent as TranscriptionAgent,
)
from .vad_agent import VADAgent as VADAgent

View File

@@ -19,13 +19,13 @@ class TranscriptionAgent(BaseAgent):
""" """
def __init__(self, audio_in_address: str): def __init__(self, audio_in_address: str):
jid = settings.agent_settings.transcription_agent_name + "@" + settings.agent_settings.host jid = settings.agent_settings.transcription_name + "@" + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.transcription_agent_name) super().__init__(jid, settings.agent_settings.transcription_name)
self.audio_in_address = audio_in_address self.audio_in_address = audio_in_address
self.audio_in_socket: azmq.Socket | None = None self.audio_in_socket: azmq.Socket | None = None
class Transcribing(CyclicBehaviour): class TranscribingBehaviour(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket): def __init__(self, audio_in_socket: azmq.Socket):
super().__init__() super().__init__()
max_concurrent_tasks = settings.behaviour_settings.transcription_max_concurrent_tasks max_concurrent_tasks = settings.behaviour_settings.transcription_max_concurrent_tasks
@@ -44,7 +44,7 @@ class TranscriptionAgent(BaseAgent):
async def _share_transcription(self, transcription: str): async def _share_transcription(self, transcription: str):
"""Share a transcription to the other agents that depend on it.""" """Share a transcription to the other agents that depend on it."""
receiver_jids = [ receiver_jids = [
settings.agent_settings.text_belief_extractor_agent_name settings.agent_settings.text_belief_extractor_name
+ "@" + "@"
+ settings.agent_settings.host, + settings.agent_settings.host,
] # Set message receivers here ] # Set message receivers here
@@ -80,7 +80,7 @@ class TranscriptionAgent(BaseAgent):
self._connect_audio_in_socket() self._connect_audio_in_socket()
transcribing = self.Transcribing(self.audio_in_socket) transcribing = self.TranscribingBehaviour(self.audio_in_socket)
transcribing.warmup() transcribing.warmup()
self.add_behaviour(transcribing) self.add_behaviour(transcribing)

View File

@@ -7,7 +7,7 @@ from spade.behaviour import CyclicBehaviour
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from .transcription.transcription_agent import TranscriptionAgent from .transcription_agent.transcription_agent import TranscriptionAgent
class SocketPoller[T]: class SocketPoller[T]:
@@ -44,7 +44,7 @@ class SocketPoller[T]:
return None return None
class Streaming(CyclicBehaviour): class StreamingBehaviour(CyclicBehaviour):
def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket): def __init__(self, audio_in_socket: azmq.Socket, audio_out_socket: azmq.Socket):
super().__init__() super().__init__()
self.audio_in_poller = SocketPoller[bytes](audio_in_socket) self.audio_in_poller = SocketPoller[bytes](audio_in_socket)
@@ -120,8 +120,8 @@ class VADAgent(BaseAgent):
""" """
def __init__(self, audio_in_address: str, audio_in_bind: bool): def __init__(self, audio_in_address: str, audio_in_bind: bool):
jid = settings.agent_settings.vad_agent_name + "@" + settings.agent_settings.host jid = settings.agent_settings.vad_name + "@" + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.vad_agent_name) super().__init__(jid, settings.agent_settings.vad_name)
self.audio_in_address = audio_in_address self.audio_in_address = audio_in_address
self.audio_in_bind = audio_in_bind self.audio_in_bind = audio_in_bind
@@ -129,7 +129,7 @@ class VADAgent(BaseAgent):
self.audio_in_socket: azmq.Socket | None = None self.audio_in_socket: azmq.Socket | None = None
self.audio_out_socket: azmq.Socket | None = None self.audio_out_socket: azmq.Socket | None = None
self.streaming_behaviour: Streaming | None = None self.streaming_behaviour: StreamingBehaviour | None = None
async def stop(self): async def stop(self):
""" """
@@ -173,7 +173,7 @@ class VADAgent(BaseAgent):
return return
audio_out_address = f"tcp://localhost:{audio_out_port}" audio_out_address = f"tcp://localhost:{audio_out_port}"
self.streaming_behaviour = Streaming(self.audio_in_socket, self.audio_out_socket) self.streaming_behaviour = StreamingBehaviour(self.audio_in_socket, self.audio_out_socket)
self.add_behaviour(self.streaming_behaviour) self.add_behaviour(self.streaming_behaviour)
# Start agents dependent on the output audio fragments here # Start agents dependent on the output audio fragments here

View File

@@ -1,162 +0,0 @@
import asyncio
import zmq
from spade.behaviour import CyclicBehaviour
from zmq.asyncio import Context
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from .ri_command_agent import RICommandAgent
class RICommunicationAgent(BaseAgent):
req_socket: zmq.Socket
_address = ""
_bind = True
def __init__(
self,
jid: str,
password: str,
port: int = settings.agent_settings.default_spade_port,
verify_security: bool = False,
address=settings.zmq_settings.ri_command_address,
bind=False,
):
super().__init__(jid, password, port, verify_security)
self._address = address
self._bind = bind
class ListenBehaviour(CyclicBehaviour):
async def run(self):
"""
Run the listening (ping) loop indefinetely.
"""
assert self.agent is not None
# We need to listen and sent pings.
message = {"endpoint": "ping", "data": {"id": "e.g. some reference id"}}
await self.agent.req_socket.send_json(message)
# Wait up to three seconds for a reply:)
try:
message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0)
# We didnt get a reply :(
except TimeoutError:
self.agent.logger.info("No ping retrieved in 3 seconds, killing myself.")
self.kill()
self.agent.logger.debug('Received message "%s"', message)
if "endpoint" not in message:
self.agent.logger.error("No received endpoint in message, excepted ping endpoint.")
return
# See what endpoint we received
match message["endpoint"]:
case "ping":
await asyncio.sleep(settings.behaviour_settings.ping_sleep_s)
case _:
self.agent.logger.info(
"Received message with topic different than ping, while ping expected."
)
async def setup(self, max_retries: int = settings.behaviour_settings.comm_setup_max_retries):
"""
Try to setup the communication agent, we have 5 retries in case we dont have a response yet.
"""
self.logger.info("Setting up %s", self.jid)
retries = 0
# Let's try a certain amount of times before failing connection
while retries < max_retries:
# Bind request socket
self.req_socket = Context.instance().socket(zmq.REQ)
if self._bind:
self.req_socket.bind(self._address)
else:
self.req_socket.connect(self._address)
# Send our message and receive one back:)
message = {"endpoint": "negotiate/ports", "data": None}
await self.req_socket.send_json(message)
try:
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
except TimeoutError:
self.logger.warning(
"No connection established in 20 seconds (attempt %d/%d)",
retries + 1,
max_retries,
)
retries += 1
continue
except Exception as e:
self.logger.error("Unexpected error during negotiation: %s", e)
retries += 1
continue
# Validate endpoint
endpoint = received_message.get("endpoint")
if endpoint != "negotiate/ports":
# TODO: Should this send a message back?
self.logger.error(
"Invalid endpoint '%s' received (attempt %d/%d)",
endpoint,
retries + 1,
max_retries,
)
retries += 1
continue
# At this point, we have a valid response
try:
for port_data in received_message["data"]:
id = port_data["id"]
port = port_data["port"]
bind = port_data["bind"]
if not bind:
addr = f"tcp://localhost:{port}"
else:
addr = f"tcp://*:{port}"
match id:
case "main":
if addr != self._address:
if not bind:
self.req_socket.connect(addr)
else:
self.req_socket.bind(addr)
case "actuation":
ri_commands_agent = RICommandAgent(
settings.agent_settings.ri_command_agent_name
+ "@"
+ settings.agent_settings.host,
settings.agent_settings.ri_command_agent_name,
address=addr,
bind=bind,
)
await ri_commands_agent.start()
case _:
self.logger.warning("Unhandled negotiation id: %s", id)
except Exception as e:
self.logger.error("Error unpacking negotiation data: %s", e)
retries += 1
continue
# setup succeeded
break
else:
self.logger.error("Failed to set up RICommunicationAgent after %d retries", max_retries)
return
# Set up ping behaviour
listen_behaviour = self.ListenBehaviour()
self.add_behaviour(listen_behaviour)
self.logger.info("Finished setting up %s", self.jid)

View File

@@ -1,20 +0,0 @@
import logging
from fastapi import APIRouter, Request
from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/command", status_code=202)
async def receive_command(command: SpeechCommand, request: Request):
# Validate and retrieve data.
SpeechCommand.model_validate(command)
topic = b"command"
pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"}

View File

@@ -0,0 +1,25 @@
import logging
from fastapi import APIRouter, Request
from control_backend.schemas.program import Program
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/program", status_code=202)
async def receive_message(program: Program, request: Request):
"""
Receives a BehaviorProgram, pydantic checks it.
Converts it into real Phase objects.
"""
logger.debug("Received raw program: %s", program)
# send away
topic = b"program"
body = program.model_dump_json().encode()
pub_socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, body])
return {"status": "Program parsed"}

View File

@@ -0,0 +1,71 @@
import asyncio
import json
import logging
import zmq.asyncio
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from zmq.asyncio import Context, Socket
from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/command", status_code=202)
async def receive_command(command: SpeechCommand, request: Request):
# Validate and retrieve data.
SpeechCommand.model_validate(command)
topic = b"command"
pub_socket: Socket = request.app.state.endpoints_pub_socket
await pub_socket.send_multipart([topic, command.model_dump_json().encode()])
return {"status": "Command received"}
@router.get("/ping_check")
async def ping(request: Request):
pass
@router.get("/ping_stream")
async def ping_stream(request: Request):
"""Stream live updates whenever the device state changes."""
async def event_stream():
# Set up internal socket to receive ping updates
sub_socket = Context.instance().socket(zmq.SUB)
sub_socket.connect(settings.zmq_settings.internal_sub_address)
sub_socket.setsockopt(zmq.SUBSCRIBE, b"ping")
connected = False
ping_frequency = 2
# Even though its most likely the updates should alternate
# (So, True - False - True - False for connectivity),
# let's still check.
while True:
try:
topic, body = await asyncio.wait_for(
sub_socket.recv_multipart(), timeout=ping_frequency
)
connected = json.loads(body)
except TimeoutError:
logger.debug("got timeout error in ping loop in ping router")
connected = False
# Stop if client disconnected
if await request.is_disconnected():
logger.info("Client disconnected from SSE")
break
logger.debug(f"Yielded new connection event in robot ping router: {str(connected)}")
connectedJson = json.dumps(connected)
yield (f"data: {connectedJson}\n\n")
return StreamingResponse(event_stream(), media_type="text/event-stream")

View File

@@ -1,6 +1,6 @@
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from control_backend.api.v1.endpoints import command, logs, message, sse from control_backend.api.v1.endpoints import logs, message, program, robot, sse
api_router = APIRouter() api_router = APIRouter()
@@ -8,6 +8,8 @@ api_router.include_router(message.router, tags=["Messages"])
api_router.include_router(sse.router, tags=["SSE"]) api_router.include_router(sse.router, tags=["SSE"])
api_router.include_router(command.router, tags=["Commands"]) api_router.include_router(robot.router, prefix="/robot", tags=["Pings", "Commands"])
api_router.include_router(logs.router, tags=["Logs"]) api_router.include_router(logs.router, tags=["Logs"])
api_router.include_router(program.router, tags=["Program"])

View File

@@ -15,22 +15,22 @@ class AgentSettings(BaseModel):
host: str = "localhost" host: str = "localhost"
# agent names # agent names
bdi_core_agent_name: str = "bdi_core" bdi_core_name: str = "bdi_core_agent"
belief_collector_agent_name: str = "belief_collector" bdi_belief_collector_name: str = "belief_collector_agent"
text_belief_extractor_agent_name: str = "text_belief_extractor" text_belief_extractor_name: str = "text_belief_extractor_agent"
vad_agent_name: str = "vad_agent" vad_name: str = "vad_agent"
llm_agent_name: str = "llm_agent" llm_name: str = "llm_agent"
test_agent_name: str = "test_agent" test_name: str = "test_agent"
transcription_agent_name: str = "transcription_agent" transcription_name: str = "transcription_agent"
ri_communication_agent_name: str = "ri_communication_agent" ri_communication_name: str = "ri_communication_agent"
ri_command_agent_name: str = "ri_command_agent" robot_speech_name: str = "robot_speech_agent"
# default SPADE port # default SPADE port
default_spade_port: int = 5222 default_spade_port: int = 5222
class BehaviourSettings(BaseModel): class BehaviourSettings(BaseModel):
ping_sleep_s: float = 1.0 sleep_s: float = 1.0
comm_setup_max_retries: int = 5 comm_setup_max_retries: int = 5
socket_poller_timeout_ms: int = 100 socket_poller_timeout_ms: int = 100

View File

@@ -7,13 +7,25 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from zmq.asyncio import Context from zmq.asyncio import Context
from control_backend.agents import ( # Act agents
BeliefCollectorAgent, # BDI agents
LLMAgent, from control_backend.agents.bdi import (
RICommunicationAgent, BDIBeliefCollectorAgent,
VADAgent, BDICoreAgent,
TextBeliefExtractorAgent,
) )
from control_backend.agents.bdi import BDICoreAgent, TBeliefExtractorAgent
# Communication agents
from control_backend.agents.communication import RICommunicationAgent
# Emotional Agents
# LLM Agents
from control_backend.agents.llm import LLMAgent
# Perceive agents
from control_backend.agents.perception import VADAgent
# Other backend imports
from control_backend.api.v1.router import api_router from control_backend.api.v1.router import api_router
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.logging import setup_logging from control_backend.logging import setup_logging
@@ -67,10 +79,10 @@ async def lifespan(app: FastAPI):
"RICommunicationAgent": ( "RICommunicationAgent": (
RICommunicationAgent, RICommunicationAgent,
{ {
"name": settings.agent_settings.ri_communication_agent_name, "name": settings.agent_settings.ri_communication_name,
"jid": f"{settings.agent_settings.ri_communication_agent_name}" "jid": f"{settings.agent_settings.ri_communication_name}"
f"@{settings.agent_settings.host}", f"@{settings.agent_settings.host}",
"password": settings.agent_settings.ri_communication_agent_name, "password": settings.agent_settings.ri_communication_name,
"address": settings.zmq_settings.ri_communication_address, "address": settings.zmq_settings.ri_communication_address,
"bind": True, "bind": True,
}, },
@@ -78,37 +90,36 @@ async def lifespan(app: FastAPI):
"LLMAgent": ( "LLMAgent": (
LLMAgent, LLMAgent,
{ {
"name": settings.agent_settings.llm_agent_name, "name": settings.agent_settings.llm_name,
"jid": f"{settings.agent_settings.llm_agent_name}@{settings.agent_settings.host}", "jid": f"{settings.agent_settings.llm_name}@{settings.agent_settings.host}",
"password": settings.agent_settings.llm_agent_name, "password": settings.agent_settings.llm_name,
}, },
), ),
"BDICoreAgent": ( "BDICoreAgent": (
BDICoreAgent, BDICoreAgent,
{ {
"name": settings.agent_settings.bdi_core_agent_name, "name": settings.agent_settings.bdi_core_name,
"jid": f"{settings.agent_settings.bdi_core_agent_name}@" "jid": f"{settings.agent_settings.bdi_core_name}@{settings.agent_settings.host}",
f"{settings.agent_settings.host}", "password": settings.agent_settings.bdi_core_name,
"password": settings.agent_settings.bdi_core_agent_name, "asl": "src/control_backend/agents/bdi/bdi_core_agent/rules.asl",
"asl": "src/control_backend/agents/bdi/rules.asl",
}, },
), ),
"BeliefCollectorAgent": ( "BeliefCollectorAgent": (
BeliefCollectorAgent, BDIBeliefCollectorAgent,
{ {
"name": settings.agent_settings.belief_collector_agent_name, "name": settings.agent_settings.bdi_belief_collector_name,
"jid": f"{settings.agent_settings.belief_collector_agent_name}@" "jid": f"{settings.agent_settings.bdi_belief_collector_name}@"
f"{settings.agent_settings.host}", f"{settings.agent_settings.host}",
"password": settings.agent_settings.belief_collector_agent_name, "password": settings.agent_settings.bdi_belief_collector_name,
}, },
), ),
"TBeliefExtractor": ( "TextBeliefExtractorAgent": (
TBeliefExtractorAgent, TextBeliefExtractorAgent,
{ {
"name": settings.agent_settings.text_belief_extractor_agent_name, "name": settings.agent_settings.text_belief_extractor_name,
"jid": f"{settings.agent_settings.text_belief_extractor_agent_name}@" "jid": f"{settings.agent_settings.text_belief_extractor_name}@"
f"{settings.agent_settings.host}", f"{settings.agent_settings.host}",
"password": settings.agent_settings.text_belief_extractor_agent_name, "password": settings.agent_settings.text_belief_extractor_name,
}, },
), ),
"VADAgent": ( "VADAgent": (
@@ -117,17 +128,23 @@ async def lifespan(app: FastAPI):
), ),
} }
vad_agent = None
for name, (agent_class, kwargs) in agents_to_start.items(): for name, (agent_class, kwargs) in agents_to_start.items():
try: try:
logger.debug("Starting agent: %s", name) logger.debug("Starting agent: %s", name)
agent_instance = agent_class(**{k: v for k, v in kwargs.items() if k != "name"}) agent_instance = agent_class(**{k: v for k, v in kwargs.items() if k != "name"})
await agent_instance.start() await agent_instance.start()
if isinstance(agent_instance, VADAgent):
vad_agent = agent_instance
logger.info("Agent '%s' started successfully.", name) logger.info("Agent '%s' started successfully.", name)
except Exception as e: except Exception as e:
logger.error("Failed to start agent '%s': %s", name, e, exc_info=True) logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
# Consider if the application should continue if an agent fails to start. # Consider if the application should continue if an agent fails to start.
raise raise
await vad_agent.streaming_behaviour.reset()
logger.info("Application startup complete.") logger.info("Application startup complete.")
yield yield

View File

@@ -0,0 +1,38 @@
from pydantic import BaseModel
class Norm(BaseModel):
id: str
name: str
value: str
class Goal(BaseModel):
id: str
name: str
description: str
achieved: bool
class Trigger(BaseModel):
id: str
label: str
type: str
value: list[str]
class PhaseData(BaseModel):
norms: list[Norm]
goals: list[Goal]
triggers: list[Trigger]
class Phase(BaseModel):
id: str
name: str
nextPhaseId: str
phaseData: PhaseData
class Program(BaseModel):
phases: list[Phase]

View File

@@ -4,12 +4,14 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import zmq import zmq
from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.agents.actuation.robot_speech_agent import RobotSpeechAgent
@pytest.fixture @pytest.fixture
def zmq_context(mocker): def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") mock_context = mocker.patch(
"control_backend.agents.actuation.robot_speech_agent.zmq.Context.instance"
)
mock_context.return_value = MagicMock() mock_context.return_value = MagicMock()
return mock_context return mock_context
@@ -19,20 +21,18 @@ async def test_setup_bind(zmq_context, mocker):
"""Test setup with bind=True""" """Test setup with bind=True"""
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True) agent = RobotSpeechAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
settings = mocker.patch("control_backend.agents.ri_command_agent.settings") settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234" settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
await agent.setup() await agent.setup()
# Ensure PUB socket bound
fake_socket.bind.assert_any_call("tcp://localhost:5555") fake_socket.bind.assert_any_call("tcp://localhost:5555")
# Ensure SUB socket connected to internal address and subscribed
fake_socket.connect.assert_any_call("tcp://internal:1234") fake_socket.connect.assert_any_call("tcp://internal:1234")
fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command") fake_socket.setsockopt.assert_any_call(zmq.SUBSCRIBE, b"command")
# Ensure behaviour attached # Ensure behaviour attached
assert any(isinstance(b, agent.SendCommandsBehaviour) for b in agent.behaviours) assert any(isinstance(b, agent.SendZMQCommandsBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -40,13 +40,12 @@ async def test_setup_connect(zmq_context, mocker):
"""Test setup with bind=False""" """Test setup with bind=False"""
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False) agent = RobotSpeechAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
settings = mocker.patch("control_backend.agents.ri_command_agent.settings") settings = mocker.patch("control_backend.agents.actuation.robot_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234" settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
await agent.setup() await agent.setup()
# Ensure PUB socket connected
fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://localhost:5555")
@@ -60,14 +59,16 @@ async def test_send_commands_behaviour_valid_message():
) )
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RICommandAgent("test@server", "password") agent = RobotSpeechAgent("test@server", "password")
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
behaviour = agent.SendCommandsBehaviour() behaviour = agent.SendZMQCommandsBehaviour()
behaviour.agent = agent behaviour.agent = agent
with patch("control_backend.agents.ri_command_agent.SpeechCommand") as MockSpeechCommand: with patch(
"control_backend.agents.actuation.robot_speech_agent.SpeechCommand"
) as MockSpeechCommand:
mock_message = MagicMock() mock_message = MagicMock()
MockSpeechCommand.model_validate.return_value = mock_message MockSpeechCommand.model_validate.return_value = mock_message
@@ -78,22 +79,20 @@ async def test_send_commands_behaviour_valid_message():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_commands_behaviour_invalid_message(caplog): async def test_send_commands_behaviour_invalid_message():
"""Test behaviour with invalid JSON message triggers error logging""" """Test behaviour with invalid JSON message triggers error logging"""
fake_socket = AsyncMock() fake_socket = AsyncMock()
fake_socket.recv_multipart = AsyncMock(return_value=(b"command", b"{invalid_json}")) fake_socket.recv_multipart = AsyncMock(return_value=(b"command", b"{invalid_json}"))
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RICommandAgent("test@server", "password") agent = RobotSpeechAgent("test@server", "password")
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
behaviour = agent.SendCommandsBehaviour() behaviour = agent.SendZMQCommandsBehaviour()
behaviour.agent = agent behaviour.agent = agent
with caplog.at_level("ERROR"): await behaviour.run()
await behaviour.run()
fake_socket.recv_multipart.assert_awaited() fake_socket.recv_multipart.assert_awaited()
fake_socket.send_json.assert_not_awaited() fake_socket.send_json.assert_not_awaited()
assert "Error processing message" in caplog.text

View File

@@ -3,7 +3,11 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest import pytest
from control_backend.agents.ri_communication_agent import RICommunicationAgent from control_backend.agents.communication.ri_communication_agent import RICommunicationAgent
def speech_agent_path():
return "control_backend.agents.communication.ri_communication_agent.RobotSpeechAgent"
def fake_json_correct_negototiate_1(): def fake_json_correct_negototiate_1():
@@ -86,7 +90,9 @@ def fake_json_invalid_id_negototiate():
@pytest.fixture @pytest.fixture
def zmq_context(mocker): def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") mock_context = mocker.patch(
"control_backend.agents.communication.ri_communication_agent.zmq.Context.instance"
)
mock_context.return_value = MagicMock() mock_context.return_value = MagicMock()
return mock_context return mock_context
@@ -100,23 +106,25 @@ async def test_setup_creates_socket_and_negotiate_1(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_1() fake_socket.recv_json = fake_json_correct_negototiate_1()
fake_socket.send_multipart = AsyncMock()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False "test@server",
"password",
address="tcp://localhost:5555",
bind=False,
) )
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
fake_socket.recv_json.assert_awaited() fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited() fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with( MockCommandAgent.assert_called_once_with(
@@ -138,23 +146,25 @@ async def test_setup_creates_socket_and_negotiate_2(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_2() fake_socket.recv_json = fake_json_correct_negototiate_2()
fake_socket.send_multipart = AsyncMock()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False "test@server",
"password",
address="tcp://localhost:5555",
bind=False,
) )
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
fake_socket.recv_json.assert_awaited() fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited() fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with( MockCommandAgent.assert_called_once_with(
@@ -168,7 +178,7 @@ async def test_setup_creates_socket_and_negotiate_2(zmq_context):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog): async def test_setup_creates_socket_and_negotiate_3(zmq_context):
""" """
Test the functionality of setup with incorrect negotiation message Test the functionality of setup with incorrect negotiation message
""" """
@@ -176,23 +186,24 @@ async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_wrong_negototiate_1() fake_socket.recv_json = fake_json_wrong_negototiate_1()
fake_socket.send_multipart = AsyncMock()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
# We are sending wrong negotiation info to the communication agent, # We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a better response, within a limited time. # so we should retry and expect a better response, within a limited time.
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
with caplog.at_level("ERROR"):
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False "test@server",
) "password",
await agent.setup(max_retries=1) address="tcp://localhost:5555",
bind=False,
)
await agent.setup(max_retries=1)
# --- Assert --- # --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://localhost:5555")
@@ -200,7 +211,6 @@ async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
# Since it failed, there should not be any command agent. # Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited() fake_agent_instance.start.assert_not_awaited()
assert "Failed to set up RICommunicationAgent" in caplog.text
# Ensure the agent did not attach a ListenBehaviour # Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@@ -215,23 +225,24 @@ async def test_setup_creates_socket_and_negotiate_4(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_3() fake_socket.recv_json = fake_json_correct_negototiate_3()
fake_socket.send_multipart = AsyncMock()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=True "test@server",
"password",
address="tcp://localhost:5555",
bind=True,
) )
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
fake_socket.bind.assert_any_call("tcp://localhost:5555") fake_socket.bind.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
fake_socket.recv_json.assert_awaited() fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited() fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with( MockCommandAgent.assert_called_once_with(
@@ -253,23 +264,24 @@ async def test_setup_creates_socket_and_negotiate_5(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_4() fake_socket.recv_json = fake_json_correct_negototiate_4()
fake_socket.send_multipart = AsyncMock()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False "test@server",
"password",
address="tcp://localhost:5555",
bind=False,
) )
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
fake_socket.recv_json.assert_awaited() fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited() fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with( MockCommandAgent.assert_called_once_with(
@@ -291,23 +303,24 @@ async def test_setup_creates_socket_and_negotiate_6(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_5() fake_socket.recv_json = fake_json_correct_negototiate_5()
fake_socket.send_multipart = AsyncMock()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False "test@server",
"password",
address="tcp://localhost:5555",
bind=False,
) )
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://localhost:5555")
fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": None}) fake_socket.send_json.assert_any_call({"endpoint": "negotiate/ports", "data": {}})
fake_socket.recv_json.assert_awaited() fake_socket.recv_json.assert_awaited()
fake_agent_instance.start.assert_awaited() fake_agent_instance.start.assert_awaited()
MockCommandAgent.assert_called_once_with( MockCommandAgent.assert_called_once_with(
@@ -321,7 +334,7 @@ async def test_setup_creates_socket_and_negotiate_6(zmq_context):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog): async def test_setup_creates_socket_and_negotiate_7(zmq_context):
""" """
Test the functionality of setup with incorrect id Test the functionality of setup with incorrect id
""" """
@@ -329,23 +342,25 @@ async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog):
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_invalid_id_negototiate() fake_socket.recv_json = fake_json_invalid_id_negototiate()
fake_socket.send_multipart = AsyncMock()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
# We are sending wrong negotiation info to the communication agent, # We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a better response, within a limited time. # so we should retry and expect a better response, within a limited time.
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
with caplog.at_level("WARNING"):
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False "test@server",
) "password",
await agent.setup(max_retries=1) address="tcp://localhost:5555",
bind=False,
)
await agent.setup(max_retries=1)
# --- Assert --- # --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://localhost:5555")
@@ -353,11 +368,10 @@ async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog):
# Since it failed, there should not be any command agent. # Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited() fake_agent_instance.start.assert_not_awaited()
assert "Unhandled negotiation id:" in caplog.text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog): async def test_setup_creates_socket_and_negotiate_timeout(zmq_context):
""" """
Test the functionality of setup with incorrect negotiation message Test the functionality of setup with incorrect negotiation message
""" """
@@ -365,55 +379,54 @@ async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog):
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
fake_socket.send_multipart = AsyncMock()
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
with caplog.at_level("WARNING"):
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False "test@server",
) "password",
await agent.setup(max_retries=1) address="tcp://localhost:5555",
bind=False,
)
await agent.setup(max_retries=1)
# --- Assert --- # --- Assert ---
fake_socket.connect.assert_any_call("tcp://localhost:5555") fake_socket.connect.assert_any_call("tcp://localhost:5555")
# Since it failed, there should not be any command agent. # Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited() fake_agent_instance.start.assert_not_awaited()
assert "No connection established in 20 seconds" in caplog.text
# Ensure the agent did not attach a ListenBehaviour # Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_listen_behaviour_ping_correct(caplog): async def test_listen_behaviour_ping_correct():
fake_socket = AsyncMock() fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}}) fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}})
fake_socket.send_multipart = AsyncMock()
# TODO: Integration test between actual server and password needed for spade agents
agent = RICommunicationAgent("test@server", "password") agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket agent._req_socket = fake_socket
agent.connected = True
behaviour = agent.ListenBehaviour() behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour) agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops) await behaviour.run()
with caplog.at_level("DEBUG"):
await behaviour.run()
fake_socket.send_json.assert_awaited() fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited() fake_socket.recv_json.assert_awaited()
assert "Received message" in caplog.text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_listen_behaviour_ping_wrong_endpoint(caplog): async def test_listen_behaviour_ping_wrong_endpoint():
""" """
Test if our listen behaviour can work with wrong messages (wrong endpoint) Test if our listen behaviour can work with wrong messages (wrong endpoint)
""" """
@@ -430,48 +443,51 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog):
], ],
} }
) )
fake_pub_socket = AsyncMock()
agent = RICommunicationAgent("test@server", "password") agent = RICommunicationAgent("test@server", "password", fake_pub_socket)
agent.req_socket = fake_socket agent._req_socket = fake_socket
agent.connected = True
behaviour = agent.ListenBehaviour() behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour) agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops) # Run once (CyclicBehaviour normally loops)
with caplog.at_level("INFO"):
await behaviour.run()
assert "Received message with topic different than ping, while ping expected." in caplog.text await behaviour.run()
fake_socket.send_json.assert_awaited() fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited() fake_socket.recv_json.assert_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_listen_behaviour_timeout(zmq_context, caplog): async def test_listen_behaviour_timeout(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
# recv_json will never resolve, simulate timeout # recv_json will never resolve, simulate timeout
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
fake_socket.send_multipart = AsyncMock()
agent = RICommunicationAgent("test@server", "password") agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket agent._req_socket = fake_socket
agent.connected = True
behaviour = agent.ListenBehaviour() behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour) agent.add_behaviour(behaviour)
with caplog.at_level("INFO"): await behaviour.run()
await behaviour.run() assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
assert not agent.connected
assert "No ping retrieved in 3 seconds" in caplog.text
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_listen_behaviour_ping_no_endpoint(caplog): async def test_listen_behaviour_ping_no_endpoint():
""" """
Test if our listen behaviour can work with wrong messages (wrong endpoint) Test if our listen behaviour can work with wrong messages (wrong endpoint)
""" """
fake_socket = AsyncMock() fake_socket = AsyncMock()
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.send_multipart = AsyncMock()
# This is a message without endpoint >:( # This is a message without endpoint >:(
fake_socket.recv_json = AsyncMock( fake_socket.recv_json = AsyncMock(
@@ -481,43 +497,45 @@ async def test_listen_behaviour_ping_no_endpoint(caplog):
) )
agent = RICommunicationAgent("test@server", "password") agent = RICommunicationAgent("test@server", "password")
agent.req_socket = fake_socket agent._req_socket = fake_socket
agent.connected = True
behaviour = agent.ListenBehaviour() behaviour = agent.ListenBehaviour()
agent.add_behaviour(behaviour) agent.add_behaviour(behaviour)
# Run once (CyclicBehaviour normally loops) await behaviour.run()
with caplog.at_level("ERROR"):
await behaviour.run()
assert "No received endpoint in message, excepted ping endpoint." in caplog.text
fake_socket.send_json.assert_awaited() fake_socket.send_json.assert_awaited()
fake_socket.recv_json.assert_awaited() fake_socket.recv_json.assert_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_unexpected_exception(zmq_context, caplog): async def test_setup_unexpected_exception(zmq_context):
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
# Simulate unexpected exception during recv_json() # Simulate unexpected exception during recv_json()
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!")) fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
fake_socket.send_multipart = AsyncMock()
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False "test@server",
"password",
address="tcp://localhost:5555",
bind=False,
) )
with caplog.at_level("ERROR"): await agent.setup(max_retries=1)
await agent.setup(max_retries=1)
# Ensure that the error was logged assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
assert "Unexpected error during negotiation: boom!" in caplog.text assert not agent.connected
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_setup_unpacking_exception(zmq_context, caplog): async def test_setup_unpacking_exception(zmq_context):
# --- Arrange --- # --- Arrange ---
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.send_multipart = AsyncMock()
# Make recv_json return malformed negotiation data to trigger unpacking exception # Make recv_json return malformed negotiation data to trigger unpacking exception
malformed_data = { malformed_data = {
@@ -526,23 +544,21 @@ async def test_setup_unpacking_exception(zmq_context, caplog):
} # missing 'port' and 'bind' } # missing 'port' and 'bind'
fake_socket.recv_json = AsyncMock(return_value=malformed_data) fake_socket.recv_json = AsyncMock(return_value=malformed_data)
# Patch RICommandAgent so it won't actually start # Patch ActSpeechAgent so it won't actually start
with patch( with patch(speech_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
agent = RICommunicationAgent( agent = RICommunicationAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False "test@server",
"password",
address="tcp://localhost:5555",
bind=False,
) )
# --- Act & Assert --- # --- Act & Assert ---
with caplog.at_level("ERROR"):
await agent.setup(max_retries=1)
# Ensure the unpacking exception was logged await agent.setup(max_retries=1)
assert "Error unpacking negotiation data" in caplog.text
# Ensure no command agent was started # Ensure no command agent was started
fake_agent_instance.start.assert_not_awaited() fake_agent_instance.start.assert_not_awaited()

View File

@@ -5,43 +5,45 @@ import pytest
import zmq import zmq
from spade.agent import Agent from spade.agent import Agent
from control_backend.agents.vad_agent import VADAgent from control_backend.agents.perception.vad_agent import VADAgent
@pytest.fixture @pytest.fixture
def zmq_context(mocker): def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") mock_context = mocker.patch("control_backend.agents.perception.vad_agent.azmq.Context.instance")
mock_context.return_value = MagicMock() mock_context.return_value = MagicMock()
return mock_context return mock_context
@pytest.fixture @pytest.fixture
def streaming(mocker): def streaming(mocker):
return mocker.patch("control_backend.agents.vad_agent.Streaming") return mocker.patch("control_backend.agents.perception.vad_agent.StreamingBehaviour")
@pytest.fixture @pytest.fixture
def transcription_agent(mocker): def per_transcription_agent(mocker):
return mocker.patch("control_backend.agents.vad_agent.TranscriptionAgent", autospec=True) return mocker.patch(
"control_backend.agents.perception.vad_agent.TranscriptionAgent", autospec=True
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_normal_setup(streaming, transcription_agent): async def test_normal_setup(streaming, per_transcription_agent):
""" """
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
sockets, and starts the TranscriptionAgent without loading real models. sockets, and starts the TranscriptionAgent without loading real models.
""" """
vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent.add_behaviour = MagicMock() per_vad_agent.add_behaviour = MagicMock()
await vad_agent.setup() await per_vad_agent.setup()
streaming.assert_called_once() streaming.assert_called_once()
vad_agent.add_behaviour.assert_called_once_with(streaming.return_value) per_vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
transcription_agent.assert_called_once() per_transcription_agent.assert_called_once()
transcription_agent.return_value.start.assert_called_once() per_transcription_agent.return_value.start.assert_called_once()
assert vad_agent.audio_in_socket is not None assert per_vad_agent.audio_in_socket is not None
assert vad_agent.audio_out_socket is not None assert per_vad_agent.audio_out_socket is not None
@pytest.mark.parametrize("do_bind", [True, False]) @pytest.mark.parametrize("do_bind", [True, False])
@@ -50,11 +52,11 @@ def test_in_socket_creation(zmq_context, do_bind: bool):
Test that the VAD agent creates an audio input socket, differentiating between binding and Test that the VAD agent creates an audio input socket, differentiating between binding and
connecting. connecting.
""" """
vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind) per_vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind)
vad_agent._connect_audio_in_socket() per_vad_agent._connect_audio_in_socket()
assert vad_agent.audio_in_socket is not None assert per_vad_agent.audio_in_socket is not None
zmq_context.return_value.socket.assert_called_once_with(zmq.SUB) zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with( zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(
@@ -74,11 +76,11 @@ def test_out_socket_creation(zmq_context):
""" """
Test that the VAD agent creates an audio output socket correctly. Test that the VAD agent creates an audio output socket correctly.
""" """
vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
vad_agent._connect_audio_out_socket() per_vad_agent._connect_audio_out_socket()
assert vad_agent.audio_out_socket is not None assert per_vad_agent.audio_out_socket is not None
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB) zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once() zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
@@ -93,28 +95,28 @@ async def test_out_socket_creation_failure(zmq_context):
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = ( zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
zmq.ZMQBindError zmq.ZMQBindError
) )
vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
await vad_agent.setup() await per_vad_agent.setup()
assert vad_agent.audio_out_socket is None assert per_vad_agent.audio_out_socket is None
mock_super_stop.assert_called_once() mock_super_stop.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stop(zmq_context, transcription_agent): async def test_stop(zmq_context, per_transcription_agent):
""" """
Test that when the VAD agent is stopped, the sockets are closed correctly. Test that when the VAD agent is stopped, the sockets are closed correctly.
""" """
vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = VADAgent("tcp://localhost:12345", False)
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint( zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
1000, 1000,
10000, 10000,
) )
await vad_agent.setup() await per_vad_agent.setup()
await vad_agent.stop() await per_vad_agent.stop()
assert zmq_context.return_value.socket.return_value.close.call_count == 2 assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert vad_agent.audio_in_socket is None assert per_vad_agent.audio_in_socket is None
assert vad_agent.audio_out_socket is None assert per_vad_agent.audio_out_socket is None

View File

@@ -5,7 +5,7 @@ import pytest
import soundfile as sf import soundfile as sf
import zmq import zmq
from control_backend.agents.vad_agent import Streaming from control_backend.agents.perception.vad_agent import StreamingBehaviour
def get_audio_chunks() -> list[bytes]: def get_audio_chunks() -> list[bytes]:
@@ -42,12 +42,12 @@ async def test_real_audio(mocker):
audio_in_socket = AsyncMock() audio_in_socket = AsyncMock()
audio_in_socket.recv.side_effect = audio_chunks audio_in_socket.recv.side_effect = audio_chunks
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)] mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)]
audio_out_socket = AsyncMock() audio_out_socket = AsyncMock()
vad_streamer = Streaming(audio_in_socket, audio_out_socket) vad_streamer = StreamingBehaviour(audio_in_socket, audio_out_socket)
vad_streamer._ready = True vad_streamer._ready = True
vad_streamer.agent = MagicMock() vad_streamer.agent = MagicMock()
for _ in audio_chunks: for _ in audio_chunks:

View File

@@ -1,61 +0,0 @@
from unittest.mock import AsyncMock
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import command
from control_backend.schemas.ri_message import SpeechCommand
@pytest.fixture
def app():
"""
Creates a FastAPI test app and attaches the router under test.
Also sets up a mock internal_comm_socket.
"""
app = FastAPI()
app.include_router(command.router)
return app
@pytest.fixture
def client(app):
"""Create a test client for the app."""
return TestClient(app)
def test_receive_command_success(client):
"""
Test for successful reception of a command. Ensures the status code is 202 and the response body
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
expected data.
"""
# Arrange
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
speech_command = SpeechCommand(**command_data)
# Act
response = client.post("/command", json=command_data)
# Assert
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
# Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with(
[b"command", speech_command.model_dump_json().encode()]
)
def test_receive_command_invalid_payload(client):
"""
Test invalid data handling (schema validation).
"""
# Missing required field(s)
bad_payload = {"invalid": "data"}
response = client.post("/command", json=bad_payload)
assert response.status_code == 422 # validation error

View File

@@ -0,0 +1,125 @@
import json
from unittest.mock import AsyncMock
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import program
from control_backend.schemas.program import Program
@pytest.fixture
def app():
"""Create a FastAPI app with the /program route and mock socket."""
app = FastAPI()
app.include_router(program.router)
return app
@pytest.fixture
def client(app):
"""Create a TestClient."""
return TestClient(app)
def make_valid_program_dict():
"""Helper to create a valid Program JSON structure."""
return {
"phases": [
{
"id": "phase1",
"name": "basephase",
"nextPhaseId": "phase2",
"phaseData": {
"norms": [{"id": "n1", "name": "norm", "value": "be nice"}],
"goals": [
{"id": "g1", "name": "goal", "description": "test goal", "achieved": False}
],
"triggers": [
{
"id": "t1",
"label": "trigger",
"type": "keyword",
"value": ["stop", "exit"],
}
],
},
}
]
}
def test_receive_program_success(client):
"""Valid Program JSON should be parsed and sent through the socket."""
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
program_dict = make_valid_program_dict()
response = client.post("/program", json=program_dict)
assert response.status_code == 202
assert response.json() == {"status": "Program parsed"}
# Verify socket call
mock_pub_socket.send_multipart.assert_awaited_once()
args, kwargs = mock_pub_socket.send_multipart.await_args
assert args[0][0] == b"program"
sent_bytes = args[0][1]
sent_obj = json.loads(sent_bytes.decode())
expected_obj = Program.model_validate(program_dict).model_dump()
assert sent_obj == expected_obj
def test_receive_program_invalid_json(client):
"""
Invalid JSON (malformed) -> FastAPI never calls endpoint.
It returns a 422 Unprocessable Entity.
"""
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# FastAPI only accepts valid JSON bodies, so send raw string
response = client.post("/program", content="{invalid json}")
assert response.status_code == 422
mock_pub_socket.send_multipart.assert_not_called()
def test_receive_program_invalid_deep_structure(client):
"""
Valid JSON but schema invalid -> Pydantic throws validation error -> 422.
"""
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
# Missing "value" in norms element
bad_program = {
"phases": [
{
"id": "phase1",
"name": "deepfail",
"nextPhaseId": "phase2",
"phaseData": {
"norms": [
{"id": "n1", "name": "norm"} # INVALID: missing "value"
],
"goals": [
{"id": "g1", "name": "goal", "description": "desc", "achieved": False}
],
"triggers": [
{"id": "t1", "label": "trigger", "type": "keyword", "value": ["start"]}
],
},
}
]
}
response = client.post("/program", json=bad_program)
assert response.status_code == 422
mock_pub_socket.send_multipart.assert_not_called()

View File

@@ -0,0 +1,156 @@
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from control_backend.api.v1.endpoints import robot
from control_backend.schemas.ri_message import SpeechCommand
@pytest.fixture
def app():
"""
Creates a FastAPI test app and attaches the router under test.
Also sets up a mock internal_comm_socket.
"""
app = FastAPI()
app.include_router(robot.router)
return app
@pytest.fixture
def client(app):
"""Create a test client for the app."""
return TestClient(app)
def test_receive_command_success(client):
"""
Test for successful reception of a command. Ensures the status code is 202 and the response body
is correct. It also verifies that the ZeroMQ socket's send_multipart method is called with the
expected data.
"""
# Arrange
mock_pub_socket = AsyncMock()
client.app.state.endpoints_pub_socket = mock_pub_socket
command_data = {"endpoint": "actuate/speech", "data": "This is a test"}
speech_command = SpeechCommand(**command_data)
# Act
response = client.post("/command", json=command_data)
# Assert
assert response.status_code == 202
assert response.json() == {"status": "Command received"}
# Verify that the ZMQ socket was used correctly
mock_pub_socket.send_multipart.assert_awaited_once_with(
[b"command", speech_command.model_dump_json().encode()]
)
def test_receive_command_invalid_payload(client):
"""
Test invalid data handling (schema validation).
"""
# Missing required field(s)
bad_payload = {"invalid": "data"}
response = client.post("/command", json=bad_payload)
assert response.status_code == 422 # validation error
def test_ping_check_returns_none(client):
"""Ensure /ping_check returns 200 and None (currently unimplemented)."""
response = client.get("/ping_check")
assert response.status_code == 200
assert response.json() is None
@pytest.mark.asyncio
async def test_ping_stream_yields_ping_event(monkeypatch):
"""Test that ping_stream yields a proper SSE message when a ping is received."""
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
mock_sub_socket.recv_multipart = AsyncMock(return_value=[b"ping", b"true"])
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
response = await robot.ping_stream(mock_request)
generator = aiter(response.body_iterator)
event = await anext(generator)
event_text = event.decode() if isinstance(event, bytes) else str(event)
assert event_text.strip() == "data: true"
with pytest.raises(StopAsyncIteration):
await anext(generator)
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()
@pytest.mark.asyncio
async def test_ping_stream_handles_timeout(monkeypatch):
"""Test that ping_stream continues looping on TimeoutError."""
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
mock_sub_socket.recv_multipart.side_effect = TimeoutError()
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(return_value=True)
response = await robot.ping_stream(mock_request)
generator = aiter(response.body_iterator)
with pytest.raises(StopAsyncIteration):
await anext(generator)
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()
@pytest.mark.asyncio
async def test_ping_stream_yields_json_values(monkeypatch):
"""Ensure ping_stream correctly parses and yields JSON body values."""
mock_sub_socket = AsyncMock()
mock_sub_socket.connect = MagicMock()
mock_sub_socket.setsockopt = MagicMock()
mock_sub_socket.recv_multipart = AsyncMock(
return_value=[b"ping", json.dumps({"connected": True}).encode()]
)
mock_context = MagicMock()
mock_context.socket.return_value = mock_sub_socket
monkeypatch.setattr(robot.Context, "instance", lambda: mock_context)
mock_request = AsyncMock()
mock_request.is_disconnected = AsyncMock(side_effect=[False, True])
response = await robot.ping_stream(mock_request)
generator = aiter(response.body_iterator)
event = await anext(generator)
event_text = event.decode() if isinstance(event, bytes) else str(event)
assert "connected" in event_text
assert "true" in event_text
mock_sub_socket.connect.assert_called_once()
mock_sub_socket.setsockopt.assert_called_once_with(robot.zmq.SUBSCRIBE, b"ping")
mock_sub_socket.recv_multipart.assert_awaited()

View File

@@ -0,0 +1,85 @@
import pytest
from pydantic import ValidationError
from control_backend.schemas.program import Goal, Norm, Phase, PhaseData, Program, Trigger
def base_norm() -> Norm:
return Norm(
id="norm1",
name="testNorm",
value="you should act nice",
)
def base_goal() -> Goal:
return Goal(
id="goal1",
name="testGoal",
description="you should act nice",
achieved=False,
)
def base_trigger() -> Trigger:
return Trigger(
id="trigger1",
label="testTrigger",
type="keyword",
value=["Stop", "Exit"],
)
def base_phase_data() -> PhaseData:
return PhaseData(
norms=[base_norm()],
goals=[base_goal()],
triggers=[base_trigger()],
)
def base_phase() -> Phase:
return Phase(
id="phase1",
name="basephase",
nextPhaseId="phase2",
phaseData=base_phase_data(),
)
def base_program() -> Program:
return Program(phases=[base_phase()])
def invalid_program() -> dict:
# wrong types inside phases list (not Phase objects)
return {
"phases": [
{"id": "phase1"}, # incomplete
{"not_a_phase": True},
]
}
def test_valid_program():
program = base_program()
validated = Program.model_validate(program)
assert isinstance(validated, Program)
assert validated.phases[0].phaseData.norms[0].name == "testNorm"
def test_valid_deepprogram():
program = base_program()
validated = Program.model_validate(program)
# validate nested components directly
phase = validated.phases[0]
assert isinstance(phase.phaseData, PhaseData)
assert isinstance(phase.phaseData.goals[0], Goal)
assert isinstance(phase.phaseData.triggers[0], Trigger)
assert isinstance(phase.phaseData.norms[0], Norm)
def test_invalid_program():
bad = invalid_program()
with pytest.raises(ValidationError):
Program.model_validate(bad)

View File

@@ -4,10 +4,12 @@ from unittest.mock import AsyncMock, MagicMock, call
import pytest import pytest
from control_backend.agents.bdi.behaviours.belief_setter import BeliefSetterBehaviour from control_backend.agents.bdi.bdi_core_agent.behaviours.belief_setter_behaviour import (
BeliefSetterBehaviour,
)
# Define a constant for the collector agent name to use in tests # Define a constant for the collector agent name to use in tests
COLLECTOR_AGENT_NAME = "belief_collector" COLLECTOR_AGENT_NAME = "belief_collector_agent"
COLLECTOR_AGENT_JID = f"{COLLECTOR_AGENT_NAME}@test" COLLECTOR_AGENT_JID = f"{COLLECTOR_AGENT_NAME}@test"
@@ -21,11 +23,12 @@ def mock_agent(mocker):
@pytest.fixture @pytest.fixture
def belief_setter(mock_agent, mocker): def belief_setter_behaviour(mock_agent, mocker):
"""Fixture to create an instance of BeliefSetterBehaviour with a mocked agent.""" """Fixture to create an instance of BeliefSetterBehaviour with a mocked agent."""
# Patch the settings to use a predictable agent name # Patch the settings to use a predictable agent name
mocker.patch( mocker.patch(
"control_backend.agents.bdi.behaviours.belief_setter.settings.agent_settings.belief_collector_agent_name", "control_backend.agents.bdi.bdi_core_agent."
"behaviours.belief_setter_behaviour.settings.agent_settings.bdi_belief_collector_name",
COLLECTOR_AGENT_NAME, COLLECTOR_AGENT_NAME,
) )
@@ -46,53 +49,53 @@ def create_mock_message(sender_node: str, body: str, thread: str) -> MagicMock:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_run_message_received(belief_setter, mocker): async def test_run_message_received(belief_setter_behaviour, mocker):
""" """
Test that when a message is received, _process_message is called. Test that when a message is received, _process_message is called.
""" """
# Arrange # Arrange
msg = MagicMock() msg = MagicMock()
belief_setter.receive.return_value = msg belief_setter_behaviour.receive.return_value = msg
mocker.patch.object(belief_setter, "_process_message") mocker.patch.object(belief_setter_behaviour, "_process_message")
# Act # Act
await belief_setter.run() await belief_setter_behaviour.run()
# Assert # Assert
belief_setter._process_message.assert_called_once_with(msg) belief_setter_behaviour._process_message.assert_called_once_with(msg)
def test_process_message_from_belief_collector(belief_setter, mocker): def test_process_message_from_bdi_belief_collector_agent(belief_setter_behaviour, mocker):
""" """
Test processing a message from the correct belief collector agent. Test processing a message from the correct belief collector agent.
""" """
# Arrange # Arrange
msg = create_mock_message(sender_node=COLLECTOR_AGENT_NAME, body="", thread="") msg = create_mock_message(sender_node=COLLECTOR_AGENT_NAME, body="", thread="")
mock_process_belief = mocker.patch.object(belief_setter, "_process_belief_message") mock_process_belief = mocker.patch.object(belief_setter_behaviour, "_process_belief_message")
# Act # Act
belief_setter._process_message(msg) belief_setter_behaviour._process_message(msg)
# Assert # Assert
mock_process_belief.assert_called_once_with(msg) mock_process_belief.assert_called_once_with(msg)
def test_process_message_from_other_agent(belief_setter, mocker): def test_process_message_from_other_agent(belief_setter_behaviour, mocker):
""" """
Test that messages from other agents are ignored. Test that messages from other agents are ignored.
""" """
# Arrange # Arrange
msg = create_mock_message(sender_node="other_agent", body="", thread="") msg = create_mock_message(sender_node="other_agent", body="", thread="")
mock_process_belief = mocker.patch.object(belief_setter, "_process_belief_message") mock_process_belief = mocker.patch.object(belief_setter_behaviour, "_process_belief_message")
# Act # Act
belief_setter._process_message(msg) belief_setter_behaviour._process_message(msg)
# Assert # Assert
mock_process_belief.assert_not_called() mock_process_belief.assert_not_called()
def test_process_belief_message_valid_json(belief_setter, mocker): def test_process_belief_message_valid_json(belief_setter_behaviour, mocker):
""" """
Test processing a valid belief message with correct thread and JSON body. Test processing a valid belief message with correct thread and JSON body.
""" """
@@ -101,16 +104,16 @@ def test_process_belief_message_valid_json(belief_setter, mocker):
msg = create_mock_message( msg = create_mock_message(
sender_node=COLLECTOR_AGENT_JID, body=json.dumps(beliefs_payload), thread="beliefs" sender_node=COLLECTOR_AGENT_JID, body=json.dumps(beliefs_payload), thread="beliefs"
) )
mock_set_beliefs = mocker.patch.object(belief_setter, "_set_beliefs") mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
# Act # Act
belief_setter._process_belief_message(msg) belief_setter_behaviour._process_belief_message(msg)
# Assert # Assert
mock_set_beliefs.assert_called_once_with(beliefs_payload) mock_set_beliefs.assert_called_once_with(beliefs_payload)
def test_process_belief_message_invalid_json(belief_setter, mocker, caplog): def test_process_belief_message_invalid_json(belief_setter_behaviour, mocker, caplog):
""" """
Test that a message with invalid JSON is handled gracefully and an error is logged. Test that a message with invalid JSON is handled gracefully and an error is logged.
""" """
@@ -118,16 +121,16 @@ def test_process_belief_message_invalid_json(belief_setter, mocker, caplog):
msg = create_mock_message( msg = create_mock_message(
sender_node=COLLECTOR_AGENT_JID, body="this is not a json string", thread="beliefs" sender_node=COLLECTOR_AGENT_JID, body="this is not a json string", thread="beliefs"
) )
mock_set_beliefs = mocker.patch.object(belief_setter, "_set_beliefs") mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
# Act # Act
belief_setter._process_belief_message(msg) belief_setter_behaviour._process_belief_message(msg)
# Assert # Assert
mock_set_beliefs.assert_not_called() mock_set_beliefs.assert_not_called()
def test_process_belief_message_wrong_thread(belief_setter, mocker): def test_process_belief_message_wrong_thread(belief_setter_behaviour, mocker):
""" """
Test that a message with an incorrect thread is ignored. Test that a message with an incorrect thread is ignored.
""" """
@@ -135,31 +138,31 @@ def test_process_belief_message_wrong_thread(belief_setter, mocker):
msg = create_mock_message( msg = create_mock_message(
sender_node=COLLECTOR_AGENT_JID, body='{"some": "data"}', thread="not_beliefs" sender_node=COLLECTOR_AGENT_JID, body='{"some": "data"}', thread="not_beliefs"
) )
mock_set_beliefs = mocker.patch.object(belief_setter, "_set_beliefs") mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
# Act # Act
belief_setter._process_belief_message(msg) belief_setter_behaviour._process_belief_message(msg)
# Assert # Assert
mock_set_beliefs.assert_not_called() mock_set_beliefs.assert_not_called()
def test_process_belief_message_empty_body(belief_setter, mocker): def test_process_belief_message_empty_body(belief_setter_behaviour, mocker):
""" """
Test that a message with an empty body is ignored. Test that a message with an empty body is ignored.
""" """
# Arrange # Arrange
msg = create_mock_message(sender_node=COLLECTOR_AGENT_JID, body="", thread="beliefs") msg = create_mock_message(sender_node=COLLECTOR_AGENT_JID, body="", thread="beliefs")
mock_set_beliefs = mocker.patch.object(belief_setter, "_set_beliefs") mock_set_beliefs = mocker.patch.object(belief_setter_behaviour, "_set_beliefs")
# Act # Act
belief_setter._process_belief_message(msg) belief_setter_behaviour._process_belief_message(msg)
# Assert # Assert
mock_set_beliefs.assert_not_called() mock_set_beliefs.assert_not_called()
def test_set_beliefs_success(belief_setter, mock_agent, caplog): def test_set_beliefs_success(belief_setter_behaviour, mock_agent, caplog):
""" """
Test that beliefs are correctly set on the agent's BDI. Test that beliefs are correctly set on the agent's BDI.
""" """
@@ -171,7 +174,7 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog):
# Act # Act
with caplog.at_level(logging.INFO): with caplog.at_level(logging.INFO):
belief_setter._set_beliefs(beliefs_to_set) belief_setter_behaviour._set_beliefs(beliefs_to_set)
# Assert # Assert
expected_calls = [ expected_calls = [
@@ -182,18 +185,18 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog):
assert mock_agent.bdi.set_belief.call_count == 2 assert mock_agent.bdi.set_belief.call_count == 2
# def test_responded_unset(belief_setter, mock_agent): # def test_responded_unset(belief_setter_behaviour, mock_agent):
# # Arrange # # Arrange
# new_beliefs = {"user_said": ["message"]} # new_beliefs = {"user_said": ["message"]}
# #
# # Act # # Act
# belief_setter._set_beliefs(new_beliefs) # belief_setter_behaviour._set_beliefs(new_beliefs)
# #
# # Assert # # Assert
# mock_agent.bdi.set_belief.assert_has_calls([call("user_said", "message")]) # mock_agent.bdi.set_belief.assert_has_calls([call("user_said", "message")])
# mock_agent.bdi.remove_belief.assert_has_calls([call("responded")]) # mock_agent.bdi.remove_belief.assert_has_calls([call("responded")])
# def test_set_beliefs_bdi_not_initialized(belief_setter, mock_agent, caplog): # def test_set_beliefs_bdi_not_initialized(belief_setter_behaviour, mock_agent, caplog):
# """ # """
# Test that a warning is logged if the agent's BDI is not initialized. # Test that a warning is logged if the agent's BDI is not initialized.
# """ # """
@@ -203,7 +206,7 @@ def test_set_beliefs_success(belief_setter, mock_agent, caplog):
# #
# # Act # # Act
# with caplog.at_level(logging.WARNING): # with caplog.at_level(logging.WARNING):
# belief_setter._set_beliefs(beliefs_to_set) # belief_setter_behaviour._set_beliefs(beliefs_to_set)
# #
# # Assert # # Assert
# assert "Cannot set beliefs, since agent's BDI is not yet initialized." in caplog.text # assert "Cannot set beliefs, since agent's BDI is not yet initialized." in caplog.text

View File

@@ -0,0 +1,101 @@
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from control_backend.agents.bdi.belief_collector_agent.behaviours.belief_collector_behaviour import ( # noqa: E501
BeliefCollectorBehaviour,
)
def create_mock_message(sender_node: str, body: str) -> MagicMock:
"""Helper function to create a configured mock message."""
msg = MagicMock()
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
msg.body = body
return msg
@pytest.fixture
def mock_agent(mocker):
"""Fixture to create a mock Agent."""
agent = MagicMock()
agent.jid = "belief_collector_agent@test"
return agent
@pytest.fixture
def bel_collector_behaviouror(mock_agent, mocker):
"""Fixture to create an instance of BelCollectorBehaviour with a mocked agent."""
# Patch asyncio.sleep to prevent tests from actually waiting
mocker.patch("asyncio.sleep", return_value=None)
collector = BeliefCollectorBehaviour()
collector.agent = mock_agent
# Mock the receive method, we will control its return value in each test
collector.receive = AsyncMock()
return collector
@pytest.mark.asyncio
async def test_run_message_received(bel_collector_behaviouror, mocker):
"""
Test that when a message is received, _process_message is called with that message.
"""
# Arrange
mock_msg = MagicMock()
bel_collector_behaviouror.receive.return_value = mock_msg
mocker.patch.object(bel_collector_behaviouror, "_process_message")
# Act
await bel_collector_behaviouror.run()
# Assert
bel_collector_behaviouror._process_message.assert_awaited_once_with(mock_msg)
@pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_type(bel_collector_behaviouror, mocker):
msg = create_mock_message(
"anyone",
json.dumps({"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}}),
)
spy = mocker.patch.object(bel_collector_behaviouror, "_handle_belief_text", new=AsyncMock())
await bel_collector_behaviouror._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_sender(bel_collector_behaviouror, mocker):
msg = create_mock_message(
"bel_text_agent_mock", json.dumps({"beliefs": {"user_said": [["hi"]]}})
)
spy = mocker.patch.object(bel_collector_behaviouror, "_handle_belief_text", new=AsyncMock())
await bel_collector_behaviouror._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_routes_to_handle_emo_text(bel_collector_behaviouror, mocker):
msg = create_mock_message("anyone", json.dumps({"type": "emotion_extraction_text"}))
spy = mocker.patch.object(bel_collector_behaviouror, "_handle_emo_text", new=AsyncMock())
await bel_collector_behaviouror._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_belief_text_happy_path_sends(bel_collector_behaviouror, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
bel_collector_behaviouror.send = AsyncMock()
await bel_collector_behaviouror._handle_belief_text(payload, "bel_text_agent_mock")
# make sure we attempted a send
bel_collector_behaviouror.send.assert_awaited_once()
@pytest.mark.asyncio
async def test_belief_text_coerces_non_strings(bel_collector_behaviouror, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}}
bel_collector_behaviouror.send = AsyncMock()
await bel_collector_behaviouror._handle_belief_text(payload, "origin")
bel_collector_behaviouror.send.assert_awaited_once()

View File

@@ -4,7 +4,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from spade.message import Message from spade.message import Message
from control_backend.agents.bdi.behaviours.text_belief_extractor import BeliefFromText from control_backend.agents.bdi.text_belief_extractor_agent.behaviours.text_belief_extractor_behaviour import ( # noqa: E501, We can't shorten this import.
TextBeliefExtractorBehaviour,
)
@pytest.fixture @pytest.fixture
@@ -15,15 +17,17 @@ def mock_settings():
""" """
# Create a mock object that mimics the nested structure # Create a mock object that mimics the nested structure
settings_mock = MagicMock() settings_mock = MagicMock()
settings_mock.agent_settings.transcription_agent_name = "transcriber" settings_mock.agent_settings.transcription_name = "transcriber"
settings_mock.agent_settings.belief_collector_agent_name = "collector" settings_mock.agent_settings.bdi_belief_collector_name = "collector"
settings_mock.agent_settings.host = "fake.host" settings_mock.agent_settings.host = "fake.host"
# Use patch to replace the settings object during the test # Use patch to replace the settings object during the test
# Adjust 'control_backend.behaviours.belief_from_text.settings' to where # Adjust 'control_backend.behaviours.belief_from_text.settings' to where
# your behaviour file imports it from. # your behaviour file imports it from.
with patch( with patch(
"control_backend.agents.bdi.behaviours.text_belief_extractor.settings", settings_mock "control_backend.agents.bdi.text_belief_extractor_agent.behaviours"
".text_belief_extractor_behaviour.settings",
settings_mock,
): ):
yield settings_mock yield settings_mock
@@ -31,10 +35,10 @@ def mock_settings():
@pytest.fixture @pytest.fixture
def behavior(mock_settings): def behavior(mock_settings):
""" """
Creates an instance of the BeliefFromText behaviour and mocks its Creates an instance of the BDITextBeliefBehaviour behaviour and mocks its
agent, logger, send, and receive methods. agent, logger, send, and receive methods.
""" """
b = BeliefFromText() b = TextBeliefExtractorBehaviour()
b.agent = MagicMock() b.agent = MagicMock()
b.send = AsyncMock() b.send = AsyncMock()
@@ -100,7 +104,7 @@ async def test_run_message_from_transcriber_demo(behavior, mock_settings, monkey
# Arrange: Create a mock message from the transcriber # Arrange: Create a mock message from the transcriber
transcription_text = "hello world" transcription_text = "hello world"
mock_msg = create_mock_message( mock_msg = create_mock_message(
mock_settings.agent_settings.transcription_agent_name, transcription_text, None mock_settings.agent_settings.transcription_name, transcription_text, None
) )
behavior.receive.return_value = mock_msg behavior.receive.return_value = mock_msg
@@ -119,7 +123,7 @@ async def test_run_message_from_transcriber_demo(behavior, mock_settings, monkey
assert ( assert (
sent_msg.to sent_msg.to
== mock_settings.agent_settings.belief_collector_agent_name == mock_settings.agent_settings.bdi_belief_collector_name
+ "@" + "@"
+ mock_settings.agent_settings.host + mock_settings.agent_settings.host
) )
@@ -159,7 +163,7 @@ async def test_process_transcription_success(behavior, mock_settings):
# 2. Inspect the sent message # 2. Inspect the sent message
sent_msg: Message = behavior.send.call_args[0][0] sent_msg: Message = behavior.send.call_args[0][0]
expected_to = ( expected_to = (
mock_settings.agent_settings.belief_collector_agent_name mock_settings.agent_settings.bdi_belief_collector_name
+ "@" + "@"
+ mock_settings.agent_settings.host + mock_settings.agent_settings.host
) )

View File

@@ -1,101 +0,0 @@
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from control_backend.agents.belief_collector.behaviours.continuous_collect import (
ContinuousBeliefCollector,
)
def create_mock_message(sender_node: str, body: str) -> MagicMock:
"""Helper function to create a configured mock message."""
msg = MagicMock()
msg.sender.node = sender_node # MagicMock automatically creates nested mocks
msg.body = body
return msg
@pytest.fixture
def mock_agent(mocker):
"""Fixture to create a mock Agent."""
agent = MagicMock()
agent.jid = "belief_collector_agent@test"
return agent
@pytest.fixture
def continuous_collector(mock_agent, mocker):
"""Fixture to create an instance of ContinuousBeliefCollector with a mocked agent."""
# Patch asyncio.sleep to prevent tests from actually waiting
mocker.patch("asyncio.sleep", return_value=None)
collector = ContinuousBeliefCollector()
collector.agent = mock_agent
# Mock the receive method, we will control its return value in each test
collector.receive = AsyncMock()
return collector
@pytest.mark.asyncio
async def test_run_message_received(continuous_collector, mocker):
"""
Test that when a message is received, _process_message is called with that message.
"""
# Arrange
mock_msg = MagicMock()
continuous_collector.receive.return_value = mock_msg
mocker.patch.object(continuous_collector, "_process_message")
# Act
await continuous_collector.run()
# Assert
continuous_collector._process_message.assert_awaited_once_with(mock_msg)
@pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker):
msg = create_mock_message(
"anyone",
json.dumps({"type": "belief_extraction_text", "beliefs": {"user_said": [["hi"]]}}),
)
spy = mocker.patch.object(continuous_collector, "_handle_belief_text", new=AsyncMock())
await continuous_collector._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker):
msg = create_mock_message(
"belief_text_agent_mock", json.dumps({"beliefs": {"user_said": [["hi"]]}})
)
spy = mocker.patch.object(continuous_collector, "_handle_belief_text", new=AsyncMock())
await continuous_collector._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_routes_to_handle_emo_text(continuous_collector, mocker):
msg = create_mock_message("anyone", json.dumps({"type": "emotion_extraction_text"}))
spy = mocker.patch.object(continuous_collector, "_handle_emo_text", new=AsyncMock())
await continuous_collector._process_message(msg)
spy.assert_awaited_once()
@pytest.mark.asyncio
async def test_belief_text_happy_path_sends(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
continuous_collector.send = AsyncMock()
await continuous_collector._handle_belief_text(payload, "belief_text_agent_mock")
# make sure we attempted a send
continuous_collector.send.assert_awaited_once()
@pytest.mark.asyncio
async def test_belief_text_coerces_non_strings(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": [["hi", 123]]}}
continuous_collector.send = AsyncMock()
await continuous_collector._handle_belief_text(payload, "origin")
continuous_collector.send.assert_awaited_once()

View File

@@ -1,7 +1,7 @@
import numpy as np import numpy as np
import pytest import pytest
from control_backend.agents.transcription.speech_recognizer import ( from control_backend.agents.perception.transcription_agent.speech_recognizer import (
OpenAIWhisperSpeechRecognizer, OpenAIWhisperSpeechRecognizer,
SpeechRecognizer, SpeechRecognizer,
) )
@@ -10,7 +10,7 @@ from control_backend.agents.transcription.speech_recognizer import (
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def patch_sr_settings(monkeypatch): def patch_sr_settings(monkeypatch):
# Patch the *module-local* settings that SpeechRecognizer imported # Patch the *module-local* settings that SpeechRecognizer imported
from control_backend.agents.transcription import speech_recognizer as sr from control_backend.agents.perception.transcription_agent import speech_recognizer as sr
# Provide real numbers for everything _estimate_max_tokens() reads # Provide real numbers for everything _estimate_max_tokens() reads
monkeypatch.setattr(sr.settings.vad_settings, "sample_rate_hz", 16_000, raising=False) monkeypatch.setattr(sr.settings.vad_settings, "sample_rate_hz", 16_000, raising=False)

View File

@@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
import zmq import zmq
from control_backend.agents.vad_agent import SocketPoller from control_backend.agents.perception.vad_agent import SocketPoller
@pytest.fixture @pytest.fixture
@@ -16,7 +16,7 @@ async def test_socket_poller_with_data(socket, mocker):
socket_data = b"test" socket_data = b"test"
socket.recv.return_value = socket_data socket.recv.return_value = socket_data
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)] mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
poller = SocketPoller(socket) poller = SocketPoller(socket)
@@ -35,7 +35,7 @@ async def test_socket_poller_with_data(socket, mocker):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_socket_poller_no_data(socket, mocker): async def test_socket_poller_no_data(socket, mocker):
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") mock_poller: MagicMock = mocker.patch("control_backend.agents.perception.vad_agent.zmq.Poller")
mock_poller.return_value.poll.return_value = [] mock_poller.return_value.poll.return_value = []
poller = SocketPoller(socket) poller = SocketPoller(socket)

View File

@@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
import numpy as np import numpy as np
import pytest import pytest
from control_backend.agents.vad_agent import Streaming from control_backend.agents.perception.vad_agent import StreamingBehaviour
@pytest.fixture @pytest.fixture
@@ -29,7 +29,7 @@ def streaming(audio_in_socket, audio_out_socket, mock_agent):
import torch import torch
torch.hub.load.return_value = (..., ...) # Mock torch.hub.load.return_value = (..., ...) # Mock
streaming = Streaming(audio_in_socket, audio_out_socket) streaming = StreamingBehaviour(audio_in_socket, audio_out_socket)
streaming._ready = True streaming._ready = True
streaming.agent = mock_agent streaming.agent = mock_agent
return streaming return streaming
@@ -38,7 +38,7 @@ def streaming(audio_in_socket, audio_out_socket, mock_agent):
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def patch_settings(monkeypatch): def patch_settings(monkeypatch):
# Patch the settings that vad_agent.run() reads # Patch the settings that vad_agent.run() reads
from control_backend.agents import vad_agent from control_backend.agents.perception import vad_agent
monkeypatch.setattr( monkeypatch.setattr(
vad_agent.settings.behaviour_settings, "vad_prob_threshold", 0.5, raising=False vad_agent.settings.behaviour_settings, "vad_prob_threshold", 0.5, raising=False