refactor: rename all agents and improve structure pt1

ref: N25B-257
This commit is contained in:
Björn Otgaar
2025-11-12 11:04:49 +01:00
parent 781a05328f
commit 0e45383027
37 changed files with 199 additions and 201 deletions

View File

@@ -1,7 +1 @@
from .base import BaseAgent as BaseAgent from .base import BaseAgent as BaseAgent
from .belief_collector.belief_collector import BeliefCollectorAgent as BeliefCollectorAgent
from .llm.llm import LLMAgent as LLMAgent
from .ri_command_agent import RICommandAgent as RICommandAgent
from .ri_communication_agent import RICommunicationAgent as RICommunicationAgent
from .transcription.transcription_agent import TranscriptionAgent as TranscriptionAgent
from .vad_agent import VADAgent as VADAgent

View File

@@ -0,0 +1 @@
from .act_speech_agent import ActSpeechAgent as ActSpeechAgent

View File

@@ -10,7 +10,7 @@ from control_backend.core.config import settings
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand
class RICommandAgent(BaseAgent): class ActSpeechAgent(BaseAgent):
subsocket: zmq.Socket subsocket: zmq.Socket
pubsocket: zmq.Socket pubsocket: zmq.Socket
address = "" address = ""

View File

@@ -1,2 +0,0 @@
from .bdi_core import BDICoreAgent as BDICoreAgent
from .text_extractor import TBeliefExtractorAgent as TBeliefExtractorAgent

View File

@@ -0,0 +1 @@
from .bdi_core_agent import BDICoreAgent as BDICoreAgent

View File

@@ -32,7 +32,7 @@ class BeliefSetterBehaviour(CyclicBehaviour):
self.agent.logger.debug("Processing message from sender: %s", sender) self.agent.logger.debug("Processing message from sender: %s", sender)
match sender: match sender:
case settings.agent_settings.belief_collector_agent_name: case settings.agent_settings.bel_collector_agent_name:
self.agent.logger.debug( self.agent.logger.debug(
"Message is from the belief collector agent. Processing as belief message." "Message is from the belief collector agent. Processing as belief message."
) )

View File

@@ -22,7 +22,7 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
speech_command = SpeechCommand(data=content) speech_command = SpeechCommand(data=content)
message = Message( message = Message(
to=settings.agent_settings.ri_command_agent_name to=settings.agent_settings.act_speech_agent_name
+ "@" + "@"
+ settings.agent_settings.host, + settings.agent_settings.host,
sender=self.agent.jid, sender=self.agent.jid,

View File

@@ -0,0 +1 @@
from .bel_collector_agent import BelCollectorAgent as BelCollectorAgent

View File

@@ -35,7 +35,7 @@ class ContinuousBeliefCollector(CyclicBehaviour):
msg_type = payload.get("type") msg_type = payload.get("type")
# Prefer explicit 'type' field # Prefer explicit 'type' field
if msg_type == "belief_extraction_text" or sender_node == "belief_text_agent_mock": if msg_type == "belief_extraction_text" or sender_node == "bel_text_agent_mock":
self.agent.logger.debug( self.agent.logger.debug(
"Message routed to _handle_belief_text (sender=%s)", sender_node "Message routed to _handle_belief_text (sender=%s)", sender_node
) )
@@ -83,7 +83,9 @@ class ContinuousBeliefCollector(CyclicBehaviour):
if not beliefs: if not beliefs:
return return
to_jid = f"{settings.agent_settings.bdi_core_agent_name}@{settings.agent_settings.host}" to_jid = (
f"{settings.agent_settings.bdi_core_agent_agent_name}@{settings.agent_settings.host}"
)
msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs") msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs")
msg.body = json.dumps(beliefs) msg.body = json.dumps(beliefs)

View File

@@ -3,9 +3,9 @@ from control_backend.agents.base import BaseAgent
from .behaviours.continuous_collect import ContinuousBeliefCollector from .behaviours.continuous_collect import ContinuousBeliefCollector
class BeliefCollectorAgent(BaseAgent): class BelCollectorAgent(BaseAgent):
async def setup(self): async def setup(self):
self.logger.info("BeliefCollectorAgent starting (%s)", self.jid) self.logger.info("BelCollectorAgent starting (%s)", self.jid)
# Attach the continuous collector behaviour (listens and forwards to BDI) # Attach the continuous collector behaviour (listens and forwards to BDI)
self.add_behaviour(ContinuousBeliefCollector()) self.add_behaviour(ContinuousBeliefCollector())
self.logger.info("BeliefCollectorAgent ready.") self.logger.info("BelCollectorAgent ready.")

View File

@@ -44,7 +44,7 @@ class BeliefFromText(CyclicBehaviour):
sender = msg.sender.node sender = msg.sender.node
match sender: match sender:
case settings.agent_settings.transcription_agent_name: case settings.agent_settings.per_transcription_agent_name:
self.logger.debug("Received text from transcriber: %s", msg.body) self.logger.debug("Received text from transcriber: %s", msg.body)
await self._process_transcription_demo(msg.body) await self._process_transcription_demo(msg.body)
case _: case _:
@@ -71,7 +71,7 @@ class BeliefFromText(CyclicBehaviour):
belief_message = Message() belief_message = Message()
belief_message.to = ( belief_message.to = (
settings.agent_settings.belief_collector_agent_name settings.agent_settings.bel_collector_agent_name
+ "@" + "@"
+ settings.agent_settings.host + settings.agent_settings.host
) )
@@ -95,7 +95,7 @@ class BeliefFromText(CyclicBehaviour):
belief_msg = Message() belief_msg = Message()
belief_msg.to = ( belief_msg.to = (
settings.agent_settings.belief_collector_agent_name + "@" + settings.agent_settings.host settings.agent_settings.bel_collector_agent_name + "@" + settings.agent_settings.host
) )
belief_msg.body = payload belief_msg.body = payload
belief_msg.thread = "beliefs" belief_msg.thread = "beliefs"

View File

@@ -3,6 +3,6 @@ from control_backend.agents.base import BaseAgent
from .behaviours.text_belief_extractor import BeliefFromText from .behaviours.text_belief_extractor import BeliefFromText
class TBeliefExtractorAgent(BaseAgent): class BelTextExtractAgent(BaseAgent):
async def setup(self): async def setup(self):
self.add_behaviour(BeliefFromText()) self.add_behaviour(BeliefFromText())

View File

@@ -0,0 +1 @@
from .com_ri_agent import ComRIAgent as ComRIAgent

View File

@@ -7,10 +7,10 @@ from zmq.asyncio import Context
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from .ri_command_agent import RICommandAgent from ..act_agents.act_speech_agent import ActSpeechAgent
class RICommunicationAgent(BaseAgent): class ComRIAgent(BaseAgent):
req_socket: zmq.Socket req_socket: zmq.Socket
_address = "" _address = ""
_bind = True _bind = True
@@ -132,11 +132,11 @@ class RICommunicationAgent(BaseAgent):
else: else:
self.req_socket.bind(addr) self.req_socket.bind(addr)
case "actuation": case "actuation":
ri_commands_agent = RICommandAgent( ri_commands_agent = ActSpeechAgent(
settings.agent_settings.ri_command_agent_name settings.agent_settings.act_speech_agent_name
+ "@" + "@"
+ settings.agent_settings.host, + settings.agent_settings.host,
settings.agent_settings.ri_command_agent_name, settings.agent_settings.act_speech_agent_name,
address=addr, address=addr,
bind=bind, bind=bind,
) )
@@ -153,7 +153,7 @@ class RICommunicationAgent(BaseAgent):
break break
else: else:
self.logger.error("Failed to set up RICommunicationAgent after %d retries", max_retries) self.logger.error("Failed to set up ComRIAgent after %d retries", max_retries)
return return
# Set up ping behaviour # Set up ping behaviour

View File

@@ -0,0 +1 @@
from .llm_agent import LLMAgent as LLMAgent

View File

@@ -39,7 +39,7 @@ class LLMAgent(BaseAgent):
sender, sender,
) )
if sender == settings.agent_settings.bdi_core_agent_name: if sender == settings.agent_settings.bdi_core_agent_agent_name:
self.agent.logger.debug("Processing message from BDI Core Agent") self.agent.logger.debug("Processing message from BDI Core Agent")
await self._process_bdi_message(msg) await self._process_bdi_message(msg)
else: else:
@@ -63,7 +63,9 @@ class LLMAgent(BaseAgent):
Sends a response message back to the BDI Core Agent. Sends a response message back to the BDI Core Agent.
""" """
reply = Message( reply = Message(
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host, to=settings.agent_settings.bdi_core_agent_agent_name
+ "@"
+ settings.agent_settings.host,
body=msg, body=msg,
) )
await self.send(reply) await self.send(reply)

View File

@@ -7,11 +7,11 @@ from spade.message import Message
from control_backend.core.config import settings from control_backend.core.config import settings
class BeliefTextAgent(Agent): class BelTextAgent(Agent):
class SendOnceBehaviourBlfText(OneShotBehaviour): class SendOnceBehaviourBlfText(OneShotBehaviour):
async def run(self): async def run(self):
to_jid = ( to_jid = (
settings.agent_settings.belief_collector_agent_name settings.agent_settings.bel_collector_agent_name
+ "@" + "@"
+ settings.agent_settings.host + settings.agent_settings.host
) )
@@ -39,6 +39,6 @@ class BeliefTextAgent(Agent):
await self.agent.stop() await self.agent.stop()
async def setup(self): async def setup(self):
print("BeliefTextAgent started") print("BelTextAgent started")
self.b = self.SendOnceBehaviourBlfText() self.b = self.SendOnceBehaviourBlfText()
self.add_behaviour(self.b) self.add_behaviour(self.b)

View File

@@ -0,0 +1,4 @@
from .per_transcription_agent.per_transcription_agent import (
PerTranscriptionAgent as PerTranscriptionAgent,
)
from .per_vad_agent import PerVADAgent as PerVADAgent

View File

@@ -12,15 +12,19 @@ from control_backend.core.config import settings
from .speech_recognizer import SpeechRecognizer from .speech_recognizer import SpeechRecognizer
class TranscriptionAgent(BaseAgent): class PerTranscriptionAgent(BaseAgent):
""" """
An agent which listens to audio fragments with voice, transcribes them, and sends the An agent which listens to audio fragments with voice, transcribes them, and sends the
transcription to other agents. transcription to other agents.
""" """
def __init__(self, audio_in_address: str): def __init__(self, audio_in_address: str):
jid = settings.agent_settings.transcription_agent_name + "@" + settings.agent_settings.host jid = (
super().__init__(jid, settings.agent_settings.transcription_agent_name) settings.agent_settings.per_transcription_agent_name
+ "@"
+ settings.agent_settings.host
)
super().__init__(jid, settings.agent_settings.per_transcription_agent_name)
self.audio_in_address = audio_in_address self.audio_in_address = audio_in_address
self.audio_in_socket: azmq.Socket | None = None self.audio_in_socket: azmq.Socket | None = None
@@ -43,7 +47,7 @@ class TranscriptionAgent(BaseAgent):
async def _share_transcription(self, transcription: str): async def _share_transcription(self, transcription: str):
"""Share a transcription to the other agents that depend on it.""" """Share a transcription to the other agents that depend on it."""
receiver_jids = [ receiver_jids = [
settings.agent_settings.text_belief_extractor_agent_name settings.agent_settings.texbel_text_extractor_agent_name
+ "@" + "@"
+ settings.agent_settings.host, + settings.agent_settings.host,
] # Set message receivers here ] # Set message receivers here

View File

@@ -7,7 +7,7 @@ from spade.behaviour import CyclicBehaviour
from control_backend.agents import BaseAgent from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from .transcription.transcription_agent import TranscriptionAgent from .per_transcription_agent.per_transcription_agent import PerTranscriptionAgent
class SocketPoller[T]: class SocketPoller[T]:
@@ -102,15 +102,15 @@ class Streaming(CyclicBehaviour):
self.audio_buffer = chunk self.audio_buffer = chunk
class VADAgent(BaseAgent): class PerVADAgent(BaseAgent):
""" """
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
fragments with detected speech to other agents over ZeroMQ. fragments with detected speech to other agents over ZeroMQ.
""" """
def __init__(self, audio_in_address: str, audio_in_bind: bool): def __init__(self, audio_in_address: str, audio_in_bind: bool):
jid = settings.agent_settings.vad_agent_name + "@" + settings.agent_settings.host jid = settings.agent_settings.per_vad_agent_name + "@" + settings.agent_settings.host
super().__init__(jid, settings.agent_settings.vad_agent_name) super().__init__(jid, settings.agent_settings.per_vad_agent_name)
self.audio_in_address = audio_in_address self.audio_in_address = audio_in_address
self.audio_in_bind = audio_in_bind self.audio_in_bind = audio_in_bind
@@ -166,7 +166,7 @@ class VADAgent(BaseAgent):
self.add_behaviour(self.streaming_behaviour) self.add_behaviour(self.streaming_behaviour)
# Start agents dependent on the output audio fragments here # Start agents dependent on the output audio fragments here
transcriber = TranscriptionAgent(audio_out_address) transcriber = PerTranscriptionAgent(audio_out_address)
await transcriber.start() await transcriber.start()
self.logger.info("Finished setting up %s", self.jid) self.logger.info("Finished setting up %s", self.jid)

View File

@@ -9,16 +9,16 @@ class ZMQSettings(BaseModel):
class AgentSettings(BaseModel): class AgentSettings(BaseModel):
host: str = "localhost" host: str = "localhost"
bdi_core_agent_name: str = "bdi_core" bdi_core_agent_agent_name: str = "bdi_core_agent"
belief_collector_agent_name: str = "belief_collector" bel_collector_agent_name: str = "bel_collector_agent"
text_belief_extractor_agent_name: str = "text_belief_extractor" texbel_text_extractor_agent_name: str = "text_belief_extractor"
vad_agent_name: str = "vad_agent" per_vad_agent_name: str = "per_vad_agent"
llm_agent_name: str = "llm_agent" llm_agent_name: str = "llm_agent"
test_agent_name: str = "test_agent" test_agent_name: str = "test_agent"
transcription_agent_name: str = "transcription_agent" per_transcription_agent_name: str = "per_transcription_agent"
ri_communication_agent_name: str = "ri_communication_agent" com_ri_agent_name: str = "com_ri_agent"
ri_command_agent_name: str = "ri_command_agent" act_speech_agent_name: str = "act_speech_agent"
class LLMSettings(BaseModel): class LLMSettings(BaseModel):

View File

@@ -8,12 +8,12 @@ from fastapi.middleware.cors import CORSMiddleware
from zmq.asyncio import Context from zmq.asyncio import Context
from control_backend.agents import ( from control_backend.agents import (
BeliefCollectorAgent, BelCollectorAgent,
ComRIAgent,
LLMAgent, LLMAgent,
RICommunicationAgent, PerVADAgent,
VADAgent,
) )
from control_backend.agents.bdi import BDICoreAgent, TBeliefExtractorAgent from control_backend.agents.bdi_agents import BDICoreAgent, BelTextExtractAgent
from control_backend.api.v1.router import api_router from control_backend.api.v1.router import api_router
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.logging import setup_logging from control_backend.logging import setup_logging
@@ -64,13 +64,13 @@ async def lifespan(app: FastAPI):
# --- Initialize Agents --- # --- Initialize Agents ---
logger.info("Initializing and starting agents.") logger.info("Initializing and starting agents.")
agents_to_start = { agents_to_start = {
"RICommunicationAgent": ( "ComRIAgent": (
RICommunicationAgent, ComRIAgent,
{ {
"name": settings.agent_settings.ri_communication_agent_name, "name": settings.agent_settings.com_ri_agent_name,
"jid": f"{settings.agent_settings.ri_communication_agent_name}" "jid": f"{settings.agent_settings.com_ri_agent_name}"
f"@{settings.agent_settings.host}", f"@{settings.agent_settings.host}",
"password": settings.agent_settings.ri_communication_agent_name, "password": settings.agent_settings.com_ri_agent_name,
"address": "tcp://*:5555", "address": "tcp://*:5555",
"bind": True, "bind": True,
}, },
@@ -86,33 +86,33 @@ async def lifespan(app: FastAPI):
"BDICoreAgent": ( "BDICoreAgent": (
BDICoreAgent, BDICoreAgent,
{ {
"name": settings.agent_settings.bdi_core_agent_name, "name": settings.agent_settings.bdi_core_agent_agent_name,
"jid": f"{settings.agent_settings.bdi_core_agent_name}@" "jid": f"{settings.agent_settings.bdi_core_agent_agent_name}@"
f"{settings.agent_settings.host}", f"{settings.agent_settings.host}",
"password": settings.agent_settings.bdi_core_agent_name, "password": settings.agent_settings.bdi_core_agent_agent_name,
"asl": "src/control_backend/agents/bdi/rules.asl", "asl": "src/control_backend/agents/bdi/rules.asl",
}, },
), ),
"BeliefCollectorAgent": ( "BelCollectorAgent": (
BeliefCollectorAgent, BelCollectorAgent,
{ {
"name": settings.agent_settings.belief_collector_agent_name, "name": settings.agent_settings.bel_collector_agent_name,
"jid": f"{settings.agent_settings.belief_collector_agent_name}@" "jid": f"{settings.agent_settings.bel_collector_agent_name}@"
f"{settings.agent_settings.host}", f"{settings.agent_settings.host}",
"password": settings.agent_settings.belief_collector_agent_name, "password": settings.agent_settings.bel_collector_agent_name,
}, },
), ),
"TBeliefExtractor": ( "TBeliefExtractor": (
TBeliefExtractorAgent, BelTextExtractAgent,
{ {
"name": settings.agent_settings.text_belief_extractor_agent_name, "name": settings.agent_settings.texbel_text_extractor_agent_name,
"jid": f"{settings.agent_settings.text_belief_extractor_agent_name}@" "jid": f"{settings.agent_settings.texbel_text_extractor_agent_name}@"
f"{settings.agent_settings.host}", f"{settings.agent_settings.host}",
"password": settings.agent_settings.text_belief_extractor_agent_name, "password": settings.agent_settings.texbel_text_extractor_agent_name,
}, },
), ),
"VADAgent": ( "PerVADAgent": (
VADAgent, PerVADAgent,
{"audio_in_address": "tcp://localhost:5558", "audio_in_bind": False}, {"audio_in_address": "tcp://localhost:5558", "audio_in_bind": False},
), ),
} }

View File

@@ -5,43 +5,47 @@ import pytest
import zmq import zmq
from spade.agent import Agent from spade.agent import Agent
from control_backend.agents.vad_agent import VADAgent from control_backend.agents.per_agents.per_vad_agent import PerVADAgent
@pytest.fixture @pytest.fixture
def zmq_context(mocker): def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") mock_context = mocker.patch(
"control_backend.agents.per_agents.per_vad_agent.azmq.Context.instance"
)
mock_context.return_value = MagicMock() mock_context.return_value = MagicMock()
return mock_context return mock_context
@pytest.fixture @pytest.fixture
def streaming(mocker): def streaming(mocker):
return mocker.patch("control_backend.agents.vad_agent.Streaming") return mocker.patch("control_backend.agents.per_agents.per_vad_agent.Streaming")
@pytest.fixture @pytest.fixture
def transcription_agent(mocker): def per_transcription_agent(mocker):
return mocker.patch("control_backend.agents.vad_agent.TranscriptionAgent", autospec=True) return mocker.patch(
"control_backend.agents.per_agents.per_vad_agent.PerTranscriptionAgent", autospec=True
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_normal_setup(streaming, transcription_agent): async def test_normal_setup(streaming, per_transcription_agent):
""" """
Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio Test that during normal setup, the VAD agent creates a Streaming behavior and creates audio
sockets, and starts the TranscriptionAgent without loading real models. sockets, and starts the PerTranscriptionAgent without loading real models.
""" """
vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = PerVADAgent("tcp://localhost:12345", False)
vad_agent.add_behaviour = MagicMock() per_vad_agent.add_behaviour = MagicMock()
await vad_agent.setup() await per_vad_agent.setup()
streaming.assert_called_once() streaming.assert_called_once()
vad_agent.add_behaviour.assert_called_once_with(streaming.return_value) per_vad_agent.add_behaviour.assert_called_once_with(streaming.return_value)
transcription_agent.assert_called_once() per_transcription_agent.assert_called_once()
transcription_agent.return_value.start.assert_called_once() per_transcription_agent.return_value.start.assert_called_once()
assert vad_agent.audio_in_socket is not None assert per_vad_agent.audio_in_socket is not None
assert vad_agent.audio_out_socket is not None assert per_vad_agent.audio_out_socket is not None
@pytest.mark.parametrize("do_bind", [True, False]) @pytest.mark.parametrize("do_bind", [True, False])
@@ -50,11 +54,11 @@ def test_in_socket_creation(zmq_context, do_bind: bool):
Test that the VAD agent creates an audio input socket, differentiating between binding and Test that the VAD agent creates an audio input socket, differentiating between binding and
connecting. connecting.
""" """
vad_agent = VADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind) per_vad_agent = PerVADAgent(f"tcp://{'*' if do_bind else 'localhost'}:12345", do_bind)
vad_agent._connect_audio_in_socket() per_vad_agent._connect_audio_in_socket()
assert vad_agent.audio_in_socket is not None assert per_vad_agent.audio_in_socket is not None
zmq_context.return_value.socket.assert_called_once_with(zmq.SUB) zmq_context.return_value.socket.assert_called_once_with(zmq.SUB)
zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with( zmq_context.return_value.socket.return_value.setsockopt_string.assert_called_once_with(
@@ -74,11 +78,11 @@ def test_out_socket_creation(zmq_context):
""" """
Test that the VAD agent creates an audio output socket correctly. Test that the VAD agent creates an audio output socket correctly.
""" """
vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = PerVADAgent("tcp://localhost:12345", False)
vad_agent._connect_audio_out_socket() per_vad_agent._connect_audio_out_socket()
assert vad_agent.audio_out_socket is not None assert per_vad_agent.audio_out_socket is not None
zmq_context.return_value.socket.assert_called_once_with(zmq.PUB) zmq_context.return_value.socket.assert_called_once_with(zmq.PUB)
zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once() zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once()
@@ -93,28 +97,28 @@ async def test_out_socket_creation_failure(zmq_context):
zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = ( zmq_context.return_value.socket.return_value.bind_to_random_port.side_effect = (
zmq.ZMQBindError zmq.ZMQBindError
) )
vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = PerVADAgent("tcp://localhost:12345", False)
await vad_agent.setup() await per_vad_agent.setup()
assert vad_agent.audio_out_socket is None assert per_vad_agent.audio_out_socket is None
mock_super_stop.assert_called_once() mock_super_stop.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stop(zmq_context, transcription_agent): async def test_stop(zmq_context, per_transcription_agent):
""" """
Test that when the VAD agent is stopped, the sockets are closed correctly. Test that when the VAD agent is stopped, the sockets are closed correctly.
""" """
vad_agent = VADAgent("tcp://localhost:12345", False) per_vad_agent = PerVADAgent("tcp://localhost:12345", False)
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint( zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
1000, 1000,
10000, 10000,
) )
await vad_agent.setup() await per_vad_agent.setup()
await vad_agent.stop() await per_vad_agent.stop()
assert zmq_context.return_value.socket.return_value.close.call_count == 2 assert zmq_context.return_value.socket.return_value.close.call_count == 2
assert vad_agent.audio_in_socket is None assert per_vad_agent.audio_in_socket is None
assert vad_agent.audio_out_socket is None assert per_vad_agent.audio_out_socket is None

View File

@@ -5,7 +5,7 @@ import pytest
import soundfile as sf import soundfile as sf
import zmq import zmq
from control_backend.agents.vad_agent import Streaming from control_backend.agents.per_agents.per_vad_agent import Streaming
def get_audio_chunks() -> list[bytes]: def get_audio_chunks() -> list[bytes]:
@@ -42,7 +42,9 @@ async def test_real_audio(mocker):
audio_in_socket = AsyncMock() audio_in_socket = AsyncMock()
audio_in_socket.recv.side_effect = audio_chunks audio_in_socket.recv.side_effect = audio_chunks
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") mock_poller: MagicMock = mocker.patch(
"control_backend.agents.per_agents.per_vad_agent.zmq.Poller"
)
mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)] mock_poller.return_value.poll.return_value = [(audio_in_socket, zmq.POLLIN)]
audio_out_socket = AsyncMock() audio_out_socket = AsyncMock()

View File

@@ -4,12 +4,14 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import zmq import zmq
from control_backend.agents.ri_command_agent import RICommandAgent from control_backend.agents.act_agents.act_speech_agent import ActSpeechAgent
@pytest.fixture @pytest.fixture
def zmq_context(mocker): def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") mock_context = mocker.patch(
"control_backend.agents.act_agents.act_speech_agent.zmq.Context.instance"
)
mock_context.return_value = MagicMock() mock_context.return_value = MagicMock()
return mock_context return mock_context
@@ -19,8 +21,8 @@ async def test_setup_bind(zmq_context, mocker):
"""Test setup with bind=True""" """Test setup with bind=True"""
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=True) agent = ActSpeechAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
settings = mocker.patch("control_backend.agents.ri_command_agent.settings") settings = mocker.patch("control_backend.agents.act_agents.act_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234" settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
await agent.setup() await agent.setup()
@@ -40,8 +42,8 @@ async def test_setup_connect(zmq_context, mocker):
"""Test setup with bind=False""" """Test setup with bind=False"""
fake_socket = zmq_context.return_value.socket.return_value fake_socket = zmq_context.return_value.socket.return_value
agent = RICommandAgent("test@server", "password", address="tcp://localhost:5555", bind=False) agent = ActSpeechAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
settings = mocker.patch("control_backend.agents.ri_command_agent.settings") settings = mocker.patch("control_backend.agents.act_agents.act_speech_agent.settings")
settings.zmq_settings.internal_sub_address = "tcp://internal:1234" settings.zmq_settings.internal_sub_address = "tcp://internal:1234"
await agent.setup() await agent.setup()
@@ -60,14 +62,16 @@ async def test_send_commands_behaviour_valid_message():
) )
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RICommandAgent("test@server", "password") agent = ActSpeechAgent("test@server", "password")
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
behaviour = agent.SendCommandsBehaviour() behaviour = agent.SendCommandsBehaviour()
behaviour.agent = agent behaviour.agent = agent
with patch("control_backend.agents.ri_command_agent.SpeechCommand") as MockSpeechCommand: with patch(
"control_backend.agents.act_agents.act_speech_agent.SpeechCommand"
) as MockSpeechCommand:
mock_message = MagicMock() mock_message = MagicMock()
MockSpeechCommand.model_validate.return_value = mock_message MockSpeechCommand.model_validate.return_value = mock_message
@@ -84,16 +88,14 @@ async def test_send_commands_behaviour_invalid_message(caplog):
fake_socket.recv_multipart = AsyncMock(return_value=(b"command", b"{invalid_json}")) fake_socket.recv_multipart = AsyncMock(return_value=(b"command", b"{invalid_json}"))
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
agent = RICommandAgent("test@server", "password") agent = ActSpeechAgent("test@server", "password")
agent.subsocket = fake_socket agent.subsocket = fake_socket
agent.pubsocket = fake_socket agent.pubsocket = fake_socket
behaviour = agent.SendCommandsBehaviour() behaviour = agent.SendCommandsBehaviour()
behaviour.agent = agent behaviour.agent = agent
with caplog.at_level("ERROR"):
await behaviour.run() await behaviour.run()
fake_socket.recv_multipart.assert_awaited() fake_socket.recv_multipart.assert_awaited()
fake_socket.send_json.assert_not_awaited() fake_socket.send_json.assert_not_awaited()
assert "Error processing message" in caplog.text

View File

@@ -3,7 +3,11 @@ from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest import pytest
from control_backend.agents.ri_communication_agent import RICommunicationAgent from control_backend.agents.com_agents.com_ri_agent import ComRIAgent
def act_agent_path():
return "control_backend.agents.com_agents.com_ri_agent.ActSpeechAgent"
def fake_json_correct_negototiate_1(): def fake_json_correct_negototiate_1():
@@ -86,7 +90,9 @@ def fake_json_invalid_id_negototiate():
@pytest.fixture @pytest.fixture
def zmq_context(mocker): def zmq_context(mocker):
mock_context = mocker.patch("control_backend.agents.vad_agent.azmq.Context.instance") mock_context = mocker.patch(
"control_backend.agents.com_agents.com_ri_agent.zmq.Context.instance"
)
mock_context.return_value = MagicMock() mock_context.return_value = MagicMock()
return mock_context return mock_context
@@ -101,17 +107,13 @@ async def test_setup_creates_socket_and_negotiate_1(zmq_context):
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_1() fake_socket.recv_json = fake_json_correct_negototiate_1()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(act_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent( agent = ComRIAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
@@ -139,17 +141,13 @@ async def test_setup_creates_socket_and_negotiate_2(zmq_context):
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_2() fake_socket.recv_json = fake_json_correct_negototiate_2()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(act_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent( agent = ComRIAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
@@ -177,19 +175,17 @@ async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_wrong_negototiate_1() fake_socket.recv_json = fake_json_wrong_negototiate_1()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
# We are sending wrong negotiation info to the communication agent, # We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a better response, within a limited time. # so we should retry and expect a better response, within a limited time.
with patch( with patch(act_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
with caplog.at_level("ERROR"): with caplog.at_level("ERROR"):
agent = RICommunicationAgent( agent = ComRIAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False "test@server", "password", address="tcp://localhost:5555", bind=False
) )
await agent.setup(max_retries=1) await agent.setup(max_retries=1)
@@ -200,7 +196,7 @@ async def test_setup_creates_socket_and_negotiate_3(zmq_context, caplog):
# Since it failed, there should not be any command agent. # Since it failed, there should not be any command agent.
fake_agent_instance.start.assert_not_awaited() fake_agent_instance.start.assert_not_awaited()
assert "Failed to set up RICommunicationAgent" in caplog.text assert "Failed to set up ComRIAgent" in caplog.text
# Ensure the agent did not attach a ListenBehaviour # Ensure the agent did not attach a ListenBehaviour
assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours) assert not any(isinstance(b, agent.ListenBehaviour) for b in agent.behaviours)
@@ -216,17 +212,13 @@ async def test_setup_creates_socket_and_negotiate_4(zmq_context):
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_3() fake_socket.recv_json = fake_json_correct_negototiate_3()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(act_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent( agent = ComRIAgent("test@server", "password", address="tcp://localhost:5555", bind=True)
"test@server", "password", address="tcp://localhost:5555", bind=True
)
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
@@ -254,17 +246,13 @@ async def test_setup_creates_socket_and_negotiate_5(zmq_context):
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_4() fake_socket.recv_json = fake_json_correct_negototiate_4()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(act_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent( agent = ComRIAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
@@ -292,17 +280,13 @@ async def test_setup_creates_socket_and_negotiate_6(zmq_context):
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_correct_negototiate_5() fake_socket.recv_json = fake_json_correct_negototiate_5()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
with patch( with patch(act_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
agent = RICommunicationAgent( agent = ComRIAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
"test@server", "password", address="tcp://localhost:5555", bind=False
)
await agent.setup() await agent.setup()
# --- Assert --- # --- Assert ---
@@ -330,19 +314,17 @@ async def test_setup_creates_socket_and_negotiate_7(zmq_context, caplog):
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = fake_json_invalid_id_negototiate() fake_socket.recv_json = fake_json_invalid_id_negototiate()
# Mock RICommandAgent agent startup # Mock ActSpeechAgent agent startup
# We are sending wrong negotiation info to the communication agent, # We are sending wrong negotiation info to the communication agent,
# so we should retry and expect a better response, within a limited time. # so we should retry and expect a better response, within a limited time.
with patch( with patch(act_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
with caplog.at_level("WARNING"): with caplog.at_level("WARNING"):
agent = RICommunicationAgent( agent = ComRIAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False "test@server", "password", address="tcp://localhost:5555", bind=False
) )
await agent.setup(max_retries=1) await agent.setup(max_retries=1)
@@ -366,15 +348,13 @@ async def test_setup_creates_socket_and_negotiate_timeout(zmq_context, caplog):
fake_socket.send_json = AsyncMock() fake_socket.send_json = AsyncMock()
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
with patch( with patch(act_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
# --- Act --- # --- Act ---
with caplog.at_level("WARNING"): with caplog.at_level("WARNING"):
agent = RICommunicationAgent( agent = ComRIAgent(
"test@server", "password", address="tcp://localhost:5555", bind=False "test@server", "password", address="tcp://localhost:5555", bind=False
) )
await agent.setup(max_retries=1) await agent.setup(max_retries=1)
@@ -397,7 +377,7 @@ async def test_listen_behaviour_ping_correct(caplog):
fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}}) fake_socket.recv_json = AsyncMock(return_value={"endpoint": "ping", "data": {}})
# TODO: Integration test between actual server and password needed for spade agents # TODO: Integration test between actual server and password needed for spade agents
agent = RICommunicationAgent("test@server", "password") agent = ComRIAgent("test@server", "password")
agent.req_socket = fake_socket agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour() behaviour = agent.ListenBehaviour()
@@ -431,7 +411,7 @@ async def test_listen_behaviour_ping_wrong_endpoint(caplog):
} }
) )
agent = RICommunicationAgent("test@server", "password") agent = ComRIAgent("test@server", "password")
agent.req_socket = fake_socket agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour() behaviour = agent.ListenBehaviour()
@@ -453,7 +433,7 @@ async def test_listen_behaviour_timeout(zmq_context, caplog):
# recv_json will never resolve, simulate timeout # recv_json will never resolve, simulate timeout
fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError) fake_socket.recv_json = AsyncMock(side_effect=asyncio.TimeoutError)
agent = RICommunicationAgent("test@server", "password") agent = ComRIAgent("test@server", "password")
agent.req_socket = fake_socket agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour() behaviour = agent.ListenBehaviour()
@@ -480,7 +460,7 @@ async def test_listen_behaviour_ping_no_endpoint(caplog):
} }
) )
agent = RICommunicationAgent("test@server", "password") agent = ComRIAgent("test@server", "password")
agent.req_socket = fake_socket agent.req_socket = fake_socket
behaviour = agent.ListenBehaviour() behaviour = agent.ListenBehaviour()
@@ -502,9 +482,7 @@ async def test_setup_unexpected_exception(zmq_context, caplog):
# Simulate unexpected exception during recv_json() # Simulate unexpected exception during recv_json()
fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!")) fake_socket.recv_json = AsyncMock(side_effect=Exception("boom!"))
agent = RICommunicationAgent( agent = ComRIAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
"test@server", "password", address="tcp://localhost:5555", bind=False
)
with caplog.at_level("ERROR"): with caplog.at_level("ERROR"):
await agent.setup(max_retries=1) await agent.setup(max_retries=1)
@@ -526,16 +504,12 @@ async def test_setup_unpacking_exception(zmq_context, caplog):
} # missing 'port' and 'bind' } # missing 'port' and 'bind'
fake_socket.recv_json = AsyncMock(return_value=malformed_data) fake_socket.recv_json = AsyncMock(return_value=malformed_data)
# Patch RICommandAgent so it won't actually start # Patch ActSpeechAgent so it won't actually start
with patch( with patch(act_agent_path(), autospec=True) as MockCommandAgent:
"control_backend.agents.ri_communication_agent.RICommandAgent", autospec=True
) as MockCommandAgent:
fake_agent_instance = MockCommandAgent.return_value fake_agent_instance = MockCommandAgent.return_value
fake_agent_instance.start = AsyncMock() fake_agent_instance.start = AsyncMock()
agent = RICommunicationAgent( agent = ComRIAgent("test@server", "password", address="tcp://localhost:5555", bind=False)
"test@server", "password", address="tcp://localhost:5555", bind=False
)
# --- Act & Assert --- # --- Act & Assert ---
with caplog.at_level("ERROR"): with caplog.at_level("ERROR"):

View File

@@ -4,10 +4,10 @@ from unittest.mock import AsyncMock, MagicMock, call
import pytest import pytest
from control_backend.agents.bdi.behaviours.belief_setter import BeliefSetterBehaviour from control_backend.agents.bdi_agents.behaviours.belief_setter import BeliefSetterBehaviour
# Define a constant for the collector agent name to use in tests # Define a constant for the collector agent name to use in tests
COLLECTOR_AGENT_NAME = "belief_collector" COLLECTOR_AGENT_NAME = "bel_collector_agent"
COLLECTOR_AGENT_JID = f"{COLLECTOR_AGENT_NAME}@test" COLLECTOR_AGENT_JID = f"{COLLECTOR_AGENT_NAME}@test"
@@ -25,7 +25,7 @@ def belief_setter(mock_agent, mocker):
"""Fixture to create an instance of BeliefSetterBehaviour with a mocked agent.""" """Fixture to create an instance of BeliefSetterBehaviour with a mocked agent."""
# Patch the settings to use a predictable agent name # Patch the settings to use a predictable agent name
mocker.patch( mocker.patch(
"control_backend.agents.bdi.behaviours.belief_setter.settings.agent_settings.belief_collector_agent_name", "control_backend.agents.bdi_agents.behaviours.belief_setter.settings.agent_settings.bel_collector_agent_name",
COLLECTOR_AGENT_NAME, COLLECTOR_AGENT_NAME,
) )
@@ -62,7 +62,7 @@ async def test_run_message_received(belief_setter, mocker):
belief_setter._process_message.assert_called_once_with(msg) belief_setter._process_message.assert_called_once_with(msg)
def test_process_message_from_belief_collector(belief_setter, mocker): def test_process_message_from_bel_collector_agent(belief_setter, mocker):
""" """
Test processing a message from the correct belief collector agent. Test processing a message from the correct belief collector agent.
""" """

View File

@@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from control_backend.agents.belief_collector.behaviours.continuous_collect import ( from control_backend.agents.bel_agents.bel_collector_agent.behaviours.continuous_collect import (
ContinuousBeliefCollector, ContinuousBeliefCollector,
) )
@@ -20,7 +20,7 @@ def create_mock_message(sender_node: str, body: str) -> MagicMock:
def mock_agent(mocker): def mock_agent(mocker):
"""Fixture to create a mock Agent.""" """Fixture to create a mock Agent."""
agent = MagicMock() agent = MagicMock()
agent.jid = "belief_collector_agent@test" agent.jid = "bel_collector_agent@test"
return agent return agent
@@ -68,7 +68,7 @@ async def test_routes_to_handle_belief_text_by_type(continuous_collector, mocker
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker): async def test_routes_to_handle_belief_text_by_sender(continuous_collector, mocker):
msg = create_mock_message( msg = create_mock_message(
"belief_text_agent_mock", json.dumps({"beliefs": {"user_said": [["hi"]]}}) "bel_text_agent_mock", json.dumps({"beliefs": {"user_said": [["hi"]]}})
) )
spy = mocker.patch.object(continuous_collector, "_handle_belief_text", new=AsyncMock()) spy = mocker.patch.object(continuous_collector, "_handle_belief_text", new=AsyncMock())
await continuous_collector._process_message(msg) await continuous_collector._process_message(msg)
@@ -87,7 +87,7 @@ async def test_routes_to_handle_emo_text(continuous_collector, mocker):
async def test_belief_text_happy_path_sends(continuous_collector, mocker): async def test_belief_text_happy_path_sends(continuous_collector, mocker):
payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}} payload = {"type": "belief_extraction_text", "beliefs": {"user_said": ["hello test", "No"]}}
continuous_collector.send = AsyncMock() continuous_collector.send = AsyncMock()
await continuous_collector._handle_belief_text(payload, "belief_text_agent_mock") await continuous_collector._handle_belief_text(payload, "bel_text_agent_mock")
# make sure we attempted a send # make sure we attempted a send
continuous_collector.send.assert_awaited_once() continuous_collector.send.assert_awaited_once()

View File

@@ -4,7 +4,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from spade.message import Message from spade.message import Message
from control_backend.agents.bdi.behaviours.text_belief_extractor import BeliefFromText from control_backend.agents.bel_agents.bel_text_extract_agent.behaviours.text_belief_extractor import ( # noqa: E501, We can't shorten this import.
BeliefFromText,
)
@pytest.fixture @pytest.fixture
@@ -15,15 +17,16 @@ def mock_settings():
""" """
# Create a mock object that mimics the nested structure # Create a mock object that mimics the nested structure
settings_mock = MagicMock() settings_mock = MagicMock()
settings_mock.agent_settings.transcription_agent_name = "transcriber" settings_mock.agent_settings.per_transcription_agent_name = "transcriber"
settings_mock.agent_settings.belief_collector_agent_name = "collector" settings_mock.agent_settings.bel_collector_agent_name = "collector"
settings_mock.agent_settings.host = "fake.host" settings_mock.agent_settings.host = "fake.host"
# Use patch to replace the settings object during the test # Use patch to replace the settings object during the test
# Adjust 'control_backend.behaviours.belief_from_text.settings' to where # Adjust 'control_backend.behaviours.belief_from_text.settings' to where
# your behaviour file imports it from. # your behaviour file imports it from.
with patch( with patch(
"control_backend.agents.bdi.behaviours.text_belief_extractor.settings", settings_mock "control_backend.agents.bel_agents.bel_text_extract_agent.behaviours.text_belief_extractor.settings",
settings_mock,
): ):
yield settings_mock yield settings_mock
@@ -100,7 +103,7 @@ async def test_run_message_from_transcriber_demo(behavior, mock_settings, monkey
# Arrange: Create a mock message from the transcriber # Arrange: Create a mock message from the transcriber
transcription_text = "hello world" transcription_text = "hello world"
mock_msg = create_mock_message( mock_msg = create_mock_message(
mock_settings.agent_settings.transcription_agent_name, transcription_text, None mock_settings.agent_settings.per_transcription_agent_name, transcription_text, None
) )
behavior.receive.return_value = mock_msg behavior.receive.return_value = mock_msg
@@ -119,7 +122,7 @@ async def test_run_message_from_transcriber_demo(behavior, mock_settings, monkey
assert ( assert (
sent_msg.to sent_msg.to
== mock_settings.agent_settings.belief_collector_agent_name == mock_settings.agent_settings.bel_collector_agent_name
+ "@" + "@"
+ mock_settings.agent_settings.host + mock_settings.agent_settings.host
) )
@@ -159,7 +162,7 @@ async def test_process_transcription_success(behavior, mock_settings):
# 2. Inspect the sent message # 2. Inspect the sent message
sent_msg: Message = behavior.send.call_args[0][0] sent_msg: Message = behavior.send.call_args[0][0]
expected_to = ( expected_to = (
mock_settings.agent_settings.belief_collector_agent_name mock_settings.agent_settings.bel_collector_agent_name
+ "@" + "@"
+ mock_settings.agent_settings.host + mock_settings.agent_settings.host
) )

View File

@@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
import zmq import zmq
from control_backend.agents.vad_agent import SocketPoller from control_backend.agents.per_agents.per_vad_agent import SocketPoller
@pytest.fixture @pytest.fixture
@@ -16,7 +16,9 @@ async def test_socket_poller_with_data(socket, mocker):
socket_data = b"test" socket_data = b"test"
socket.recv.return_value = socket_data socket.recv.return_value = socket_data
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") mock_poller: MagicMock = mocker.patch(
"control_backend.agents.per_agents.per_vad_agent.zmq.Poller"
)
mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)] mock_poller.return_value.poll.return_value = [(socket, zmq.POLLIN)]
poller = SocketPoller(socket) poller = SocketPoller(socket)
@@ -35,7 +37,9 @@ async def test_socket_poller_with_data(socket, mocker):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_socket_poller_no_data(socket, mocker): async def test_socket_poller_no_data(socket, mocker):
mock_poller: MagicMock = mocker.patch("control_backend.agents.vad_agent.zmq.Poller") mock_poller: MagicMock = mocker.patch(
"control_backend.agents.per_agents.per_vad_agent.zmq.Poller"
)
mock_poller.return_value.poll.return_value = [] mock_poller.return_value.poll.return_value = []
poller = SocketPoller(socket) poller = SocketPoller(socket)

View File

@@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
import numpy as np import numpy as np
import pytest import pytest
from control_backend.agents.vad_agent import Streaming from control_backend.agents.per_agents.per_vad_agent import Streaming
@pytest.fixture @pytest.fixture
@@ -20,7 +20,7 @@ def audio_out_socket():
def mock_agent(mocker): def mock_agent(mocker):
"""Fixture to create a mock BDIAgent.""" """Fixture to create a mock BDIAgent."""
agent = MagicMock() agent = MagicMock()
agent.jid = "vad_agent@test" agent.jid = "per_vad_agent@test"
return agent return agent

View File

@@ -1,6 +1,6 @@
import numpy as np import numpy as np
from control_backend.agents.transcription.speech_recognizer import ( from control_backend.agents.per_agents.per_transcription_agent.speech_recognizer import (
OpenAIWhisperSpeechRecognizer, OpenAIWhisperSpeechRecognizer,
SpeechRecognizer, SpeechRecognizer,
) )