diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..d498054 --- /dev/null +++ b/.env.example @@ -0,0 +1,20 @@ +# Example .env file. To use, make a copy, call it ".env" (i.e. removing the ".example" suffix), then edit the values. + +# The hostname of the Robot Interface. Change if the Control Backend and Robot Interface are running on different computers. +RI_HOST="localhost" + +# URL for the local LLM API. Must be an API that implements the OpenAI Chat Completions API, but most do. +LLM_SETTINGS__LOCAL_LLM_URL="http://localhost:1234/v1/chat/completions" + +# Name of the local LLM model to use. +LLM_SETTINGS__LOCAL_LLM_MODEL="gpt-oss" + +# Number of consecutive non-speech chunks to wait before speech is considered to have ended. A chunk is approximately 31 ms. Increasing this number allows longer pauses in speech, but also increases response time. +BEHAVIOUR_SETTINGS__VAD_NON_SPEECH_PATIENCE_CHUNKS=3 + +# Timeout in milliseconds for socket polling. Increase this number if network latency/jitter is high, often the case when using Wi-Fi. Perhaps 500 ms. A symptom of this issue is transcriptions getting cut off. +BEHAVIOUR_SETTINGS__SOCKET_POLLER_TIMEOUT_MS=100 + + + +# For an exhaustive list of options, see the control_backend.core.config module in the docs. diff --git a/README.md b/README.md index 1527215..03dac9a 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ This + part might differ based on what model you choose. copy the model name in the module loaded and replace local_llm_modelL. In settings. + ## Running To run the project (development server), execute the following command (while inside the root repository): @@ -34,6 +35,14 @@ To run the project (development server), execute the following command (while in uv run fastapi dev src/control_backend/main.py ``` +### Environment Variables + +You can use environment variables to change settings. Make a copy of the [`.env.example`](.env.example) file, name it `.env` and put it in the root directory. 
The file itself describes how to do the configuration. + +For an exhaustive list of environment options, see the `control_backend.core.config` module in the docs. + + + ## Testing Testing happens automatically when opening a merge request to any branch. If you want to manually run the test suite, you can do so by running the following for unit tests: diff --git a/src/control_backend/agents/actuation/robot_gesture_agent.py b/src/control_backend/agents/actuation/robot_gesture_agent.py index b7cb9eb..256389a 100644 --- a/src/control_backend/agents/actuation/robot_gesture_agent.py +++ b/src/control_backend/agents/actuation/robot_gesture_agent.py @@ -36,7 +36,7 @@ class RobotGestureAgent(BaseAgent): def __init__( self, name: str, - address=settings.zmq_settings.ri_command_address, + address: str, bind=False, gesture_tags=None, gesture_basic=None, @@ -135,7 +135,7 @@ class RobotGestureAgent(BaseAgent): body = json.loads(body) gesture_command = GestureCommand.model_validate(body) if gesture_command.endpoint == RIEndpoint.GESTURE_TAG: - if gesture_command.data not in self.gesture_data: + if gesture_command.data not in (self.gesture_tags or self.gesture_single): self.logger.warning( "Received gesture tag '%s' which is not in available tags.\ Early returning", diff --git a/src/control_backend/agents/bdi/bdi_program_manager.py b/src/control_backend/agents/bdi/bdi_program_manager.py index 83dea93..2f4f850 100644 --- a/src/control_backend/agents/bdi/bdi_program_manager.py +++ b/src/control_backend/agents/bdi/bdi_program_manager.py @@ -60,24 +60,41 @@ class BDIProgramManager(BaseAgent): await self.send(message) self.logger.debug("Sent new norms and goals to the BDI agent.") + async def _send_clear_llm_history(self): + """ + Clear the LLM Agent's conversation history. + + Sends a ``clear_history`` command to the LLM Agent to reset its state. 
+ """ + message = InternalMessage( + to=settings.agent_settings.llm_name, + sender=self.name, + body="clear_history", + threads="clear history message", + ) + await self.send(message) + self.logger.debug("Sent message to LLM agent to clear history.") + async def _receive_programs(self): """ Continuous loop that receives program updates from the HTTP endpoint. It listens to the ``program`` topic on the internal ZMQ SUB socket. When a program is received, it is validated and forwarded to BDI via :meth:`_send_to_bdi`. + Additionally, the LLM history is cleared via :meth:`_send_clear_llm_history`. """ while True: topic, body = await self.sub_socket.recv_multipart() try: program = Program.model_validate_json(body) + await self._send_to_bdi(program) + await self._send_clear_llm_history() + except ValidationError: self.logger.exception("Received an invalid program.") continue - await self._send_to_bdi(program) - async def setup(self): """ Initialize the agent. diff --git a/src/control_backend/agents/communication/ri_communication_agent.py b/src/control_backend/agents/communication/ri_communication_agent.py index ba0c8e0..74ca6a5 100644 --- a/src/control_backend/agents/communication/ri_communication_agent.py +++ b/src/control_backend/agents/communication/ri_communication_agent.py @@ -38,7 +38,7 @@ class RICommunicationAgent(BaseAgent): def __init__( self, name: str, - address=settings.zmq_settings.ri_command_address, + address=settings.zmq_settings.ri_communication_address, bind=False, ): super().__init__(name) @@ -168,7 +168,7 @@ class RICommunicationAgent(BaseAgent): bind = port_data["bind"] if not bind: - addr = f"tcp://localhost:{port}" + addr = f"tcp://{settings.ri_host}:{port}" else: addr = f"tcp://*:{port}" diff --git a/src/control_backend/agents/llm/llm_agent.py b/src/control_backend/agents/llm/llm_agent.py index 55099e2..60c585f 100644 --- a/src/control_backend/agents/llm/llm_agent.py +++ b/src/control_backend/agents/llm/llm_agent.py @@ -52,6 +52,10 @@ class 
LLMAgent(BaseAgent): await self._process_bdi_message(prompt_message) except ValidationError: self.logger.debug("Prompt message from BDI core is invalid.") + elif msg.sender == settings.agent_settings.bdi_program_manager_name: + if msg.body == "clear_history": + self.logger.debug("Clearing conversation history.") + self.history.clear() else: self.logger.debug("Message ignored (not from BDI core.") diff --git a/src/control_backend/agents/perception/vad_agent.py b/src/control_backend/agents/perception/vad_agent.py index 8ccff0a..70fa9e1 100644 --- a/src/control_backend/agents/perception/vad_agent.py +++ b/src/control_backend/agents/perception/vad_agent.py @@ -103,12 +103,11 @@ class VADAgent(BaseAgent): self._connect_audio_in_socket() - audio_out_port = self._connect_audio_out_socket() - if audio_out_port is None: + audio_out_address = self._connect_audio_out_socket() + if audio_out_address is None: self.logger.error("Could not bind output socket, stopping.") await self.stop() return - audio_out_address = f"tcp://localhost:{audio_out_port}" # Connect to internal communication socket self.program_sub_socket = azmq.Context.instance().socket(zmq.SUB) @@ -161,13 +160,14 @@ class VADAgent(BaseAgent): self.audio_in_socket.connect(self.audio_in_address) self.audio_in_poller = SocketPoller[bytes](self.audio_in_socket) - def _connect_audio_out_socket(self) -> int | None: + def _connect_audio_out_socket(self) -> str | None: """ - Returns the port bound, or None if binding failed. + Returns the address that was bound to, or None if binding failed. 
""" try: self.audio_out_socket = azmq.Context.instance().socket(zmq.PUB) - return self.audio_out_socket.bind_to_random_port("tcp://localhost", max_tries=100) + self.audio_out_socket.bind(settings.zmq_settings.vad_pub_address) + return settings.zmq_settings.vad_pub_address except zmq.ZMQBindError: self.logger.error("Failed to bind an audio output socket after 100 tries.") self.audio_out_socket = None diff --git a/src/control_backend/core/config.py b/src/control_backend/core/config.py index 927985b..c4a4db7 100644 --- a/src/control_backend/core/config.py +++ b/src/control_backend/core/config.py @@ -1,3 +1,12 @@ +""" +An exhaustive overview of configurable options. All of these can be set using environment variables +by nesting with double underscores (__). Start from the ``Settings`` class. + +For example, ``settings.ri_host`` becomes ``RI_HOST``, and +``settings.zmq_settings.ri_communication_address`` becomes +``ZMQ_SETTINGS__RI_COMMUNICATION_ADDRESS``. +""" + from pydantic import BaseModel from pydantic_settings import BaseSettings, SettingsConfigDict @@ -8,16 +17,17 @@ class ZMQSettings(BaseModel): :ivar internal_pub_address: Address for the internal PUB socket. :ivar internal_sub_address: Address for the internal SUB socket. - :ivar ri_command_address: Address for sending commands to the Robot Interface. - :ivar ri_communication_address: Address for receiving communication from the Robot Interface. - :ivar vad_agent_address: Address for the Voice Activity Detection (VAD) agent. + :ivar ri_communication_address: Address for the endpoint that the Robot Interface connects to. + :ivar vad_pub_address: Address that the VAD agent binds to and publishes audio segments to. 
""" + # ATTENTION: When adding/removing settings, make sure to update the .env.example file + internal_pub_address: str = "tcp://localhost:5560" internal_sub_address: str = "tcp://localhost:5561" - ri_command_address: str = "tcp://localhost:0000" ri_communication_address: str = "tcp://*:5555" internal_gesture_rep_adress: str = "tcp://localhost:7788" + vad_pub_address: str = "inproc://vad_stream" class AgentSettings(BaseModel): @@ -36,6 +46,8 @@ class AgentSettings(BaseModel): :ivar robot_speech_name: Name of the Robot Speech Agent. """ + # ATTENTION: When adding/removing settings, make sure to update the .env.example file + # agent names bdi_core_name: str = "bdi_core_agent" bdi_belief_collector_name: str = "belief_collector_agent" @@ -67,6 +79,8 @@ class BehaviourSettings(BaseModel): :ivar transcription_token_buffer: Buffer for transcription tokens. """ + # ATTENTION: When adding/removing settings, make sure to update the .env.example file + sleep_s: float = 1.0 comm_setup_max_retries: int = 5 socket_poller_timeout_ms: int = 100 @@ -91,6 +105,8 @@ class LLMSettings(BaseModel): :ivar local_llm_model: Name of the local LLM model to use. """ + # ATTENTION: When adding/removing settings, make sure to update the .env.example file + local_llm_url: str = "http://localhost:1234/v1/chat/completions" local_llm_model: str = "gpt-oss" @@ -104,6 +120,8 @@ class VADSettings(BaseModel): :ivar sample_rate_hz: Sample rate in Hz for the VAD model. """ + # ATTENTION: When adding/removing settings, make sure to update the .env.example file + repo_or_dir: str = "snakers4/silero-vad" model_name: str = "silero_vad" sample_rate_hz: int = 16000 @@ -117,6 +135,8 @@ class SpeechModelSettings(BaseModel): :ivar openai_model_name: Model name for OpenAI-based speech recognition. 
""" + # ATTENTION: When adding/removing settings, make sure to update the .env.example file + # model identifiers for speech recognition mlx_model_name: str = "mlx-community/whisper-small.en-mlx" openai_model_name: str = "small.en" @@ -128,6 +148,7 @@ class Settings(BaseSettings): :ivar app_title: Title of the application. :ivar ui_url: URL of the frontend UI. + :ivar ri_host: The hostname of the Robot Interface. :ivar zmq_settings: ZMQ configuration. :ivar agent_settings: Agent name configuration. :ivar behaviour_settings: Behavior configuration. @@ -140,6 +161,8 @@ class Settings(BaseSettings): ui_url: str = "http://localhost:5173" + ri_host: str = "localhost" + zmq_settings: ZMQSettings = ZMQSettings() agent_settings: AgentSettings = AgentSettings() diff --git a/test/integration/agents/perception/vad_agent/test_vad_agent.py b/test/integration/agents/perception/vad_agent/test_vad_agent.py index f5f2615..668d1ce 100644 --- a/test/integration/agents/perception/vad_agent/test_vad_agent.py +++ b/test/integration/agents/perception/vad_agent/test_vad_agent.py @@ -91,7 +91,7 @@ def test_out_socket_creation(zmq_context): assert per_vad_agent.audio_out_socket is not None zmq_context.return_value.socket.assert_called_once_with(zmq.PUB) - zmq_context.return_value.socket.return_value.bind_to_random_port.assert_called_once() + zmq_context.return_value.socket.return_value.bind.assert_called_once_with("inproc://vad_stream") @pytest.mark.asyncio diff --git a/test/unit/agents/actuation/test_robot_gesture_agent.py b/test/unit/agents/actuation/test_robot_gesture_agent.py index b7a9281..1211d07 100644 --- a/test/unit/agents/actuation/test_robot_gesture_agent.py +++ b/test/unit/agents/actuation/test_robot_gesture_agent.py @@ -11,7 +11,6 @@ from control_backend.schemas.ri_message import RIEndpoint @pytest.fixture def zmq_context(mocker): - """Mock the ZMQ context.""" mock_context = mocker.patch( "control_backend.agents.actuation.robot_gesture_agent.azmq.Context.instance" ) @@ -59,19 
+58,16 @@ async def test_setup_connect(zmq_context, mocker): @pytest.mark.asyncio -async def test_handle_message_forwards_valid_command(): +async def test_handle_message_valid_gesture_tag(): pubsocket = AsyncMock() - agent = RobotGestureAgent( "robot_gesture", + address="", gesture_tags=["hello"], ) agent.pubsocket = pubsocket - payload = { - "endpoint": RIEndpoint.GESTURE_TAG, - "data": "hello", - } + payload = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "hello"} msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) await agent.handle_message(msg) @@ -80,13 +76,31 @@ async def test_handle_message_forwards_valid_command(): @pytest.mark.asyncio -async def test_handle_message_invalid_payload(): +async def test_handle_message_invalid_gesture_tag(): pubsocket = AsyncMock() - agent = RobotGestureAgent("robot_gesture") + agent = RobotGestureAgent( + "robot_gesture", + address="", + gesture_tags=["hello"], + ) + agent.pubsocket = pubsocket + + payload = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "nope"} + msg = InternalMessage(to="robot", sender="tester", body=json.dumps(payload)) + + await agent.handle_message(msg) + + pubsocket.send_json.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_handle_message_invalid_payload_logged(): + pubsocket = AsyncMock() + agent = RobotGestureAgent("robot_gesture", address="") agent.pubsocket = pubsocket agent.logger = MagicMock() - msg = InternalMessage(to="robot", sender="tester", body=json.dumps({"bad": "data"})) + msg = InternalMessage(to="robot", sender="tester", body="not json") await agent.handle_message(msg) @@ -95,22 +109,25 @@ async def test_handle_message_invalid_payload(): @pytest.mark.asyncio -async def test_zmq_command_loop_valid_payload(): - """UI command with valid payload is published.""" - command = {"endpoint": RIEndpoint.GESTURE_TAG, "data": "hello"} +async def test_zmq_command_loop_valid_gesture(): fake_socket = AsyncMock() async def recv_once(): agent._running = False - return 
(b"command", json.dumps(command).encode("utf-8")) + return b"command", json.dumps( + {"endpoint": RIEndpoint.GESTURE_TAG, "data": "hello"} + ).encode() fake_socket.recv_multipart = recv_once fake_socket.send_json = AsyncMock() - agent = RobotGestureAgent("robot_gesture") + agent = RobotGestureAgent( + "robot_gesture", + address="", + gesture_tags=["hello"], + ) agent.subsocket = fake_socket agent.pubsocket = fake_socket - agent.gesture_data = ["hello", "yes", "no"] # ← REQUIRED for legacy check agent._running = True await agent._zmq_command_loop() @@ -119,17 +136,23 @@ async def test_zmq_command_loop_valid_payload(): @pytest.mark.asyncio -async def test_zmq_command_loop_ignores_send_gestures(): +async def test_zmq_command_loop_invalid_tag(): fake_socket = AsyncMock() async def recv_once(): agent._running = False - return (b"send_gestures", b"{}") + return b"command", json.dumps( + {"endpoint": RIEndpoint.GESTURE_TAG, "data": "invalid"} + ).encode() fake_socket.recv_multipart = recv_once fake_socket.send_json = AsyncMock() - agent = RobotGestureAgent("robot_gesture") + agent = RobotGestureAgent( + "robot_gesture", + address="", + gesture_tags=["hello"], + ) agent.subsocket = fake_socket agent.pubsocket = fake_socket agent._running = True @@ -140,7 +163,28 @@ async def test_zmq_command_loop_ignores_send_gestures(): @pytest.mark.asyncio -async def test_fetch_gestures_tags_all(): +async def test_zmq_command_loop_ignores_send_gestures_topic(): + fake_socket = AsyncMock() + + async def recv_once(): + agent._running = False + return b"send_gestures", b"{}" + + fake_socket.recv_multipart = recv_once + fake_socket.send_json = AsyncMock() + + agent = RobotGestureAgent("robot_gesture", address="") + agent.subsocket = fake_socket + agent.pubsocket = fake_socket + agent._running = True + + await agent._zmq_command_loop() + + fake_socket.send_json.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_fetch_gestures_tags(): fake_repsocket = AsyncMock() async def 
recv_once(): @@ -152,6 +196,7 @@ async def test_fetch_gestures_tags_all(): agent = RobotGestureAgent( "robot_gesture", + address="", gesture_tags=["hello", "yes", "no"], ) agent.repsocket = fake_repsocket @@ -163,31 +208,7 @@ async def test_fetch_gestures_tags_all(): @pytest.mark.asyncio -async def test_fetch_gestures_tags_with_count(): - fake_repsocket = AsyncMock() - - async def recv_once(): - agent._running = False - return {"type": "tags", "count": 2} - - fake_repsocket.recv_json = recv_once - fake_repsocket.send_json = AsyncMock() - - agent = RobotGestureAgent( - "robot_gesture", - gesture_tags=["hello", "yes", "no"], - ) - agent.repsocket = fake_repsocket - agent._running = True - - await agent._fetch_gestures_loop() - - fake_repsocket.send_json.assert_awaited_once_with({"tags": ["hello", "yes"]}) - - -@pytest.mark.asyncio -async def test_fetch_gestures_basic_new(): - """NEW: fetch basic gestures""" +async def test_fetch_gestures_basic(): fake_repsocket = AsyncMock() async def recv_once(): @@ -199,6 +220,7 @@ async def test_fetch_gestures_basic_new(): agent = RobotGestureAgent( "robot_gesture", + address="", gesture_basic=["wave", "point"], ) agent.repsocket = fake_repsocket @@ -220,48 +242,10 @@ async def test_fetch_gestures_unknown_type(): fake_repsocket.recv_json = recv_once fake_repsocket.send_json = AsyncMock() - agent = RobotGestureAgent("robot_gesture") + agent = RobotGestureAgent("robot_gesture", address="") agent.repsocket = fake_repsocket agent._running = True await agent._fetch_gestures_loop() fake_repsocket.send_json.assert_awaited_once_with({}) - - -@pytest.mark.asyncio -async def test_fetch_gestures_exception_logged(): - fake_repsocket = AsyncMock() - - async def recv_once(): - agent._running = False - raise Exception("boom") - - fake_repsocket.recv_json = recv_once - fake_repsocket.send_json = AsyncMock() - - agent = RobotGestureAgent("robot_gesture") - agent.repsocket = fake_repsocket - agent.logger = MagicMock() - agent._running = True - - 
await agent._fetch_gestures_loop() - - agent.logger.exception.assert_called_once() - - -@pytest.mark.asyncio -async def test_stop_closes_sockets(): - pubsocket = MagicMock() - subsocket = MagicMock() - repsocket = MagicMock() - - agent = RobotGestureAgent("robot_gesture") - agent.pubsocket = pubsocket - agent.subsocket = subsocket - agent.repsocket = repsocket - - await agent.stop() - - pubsocket.close.assert_called_once() - subsocket.close.assert_called_once() diff --git a/test/unit/agents/bdi/test_bdi_program_manager.py b/test/unit/agents/bdi/test_bdi_program_manager.py index a54360c..573524e 100644 --- a/test/unit/agents/bdi/test_bdi_program_manager.py +++ b/test/unit/agents/bdi/test_bdi_program_manager.py @@ -63,6 +63,7 @@ async def test_receive_programs_valid_and_invalid(): manager = BDIProgramManager(name="program_manager_test") manager.sub_socket = sub manager._send_to_bdi = AsyncMock() + manager._send_clear_llm_history = AsyncMock() try: # Will give StopAsyncIteration when the predefined `sub.recv_multipart` side-effects run out @@ -75,3 +76,24 @@ async def test_receive_programs_valid_and_invalid(): forwarded: Program = manager._send_to_bdi.await_args[0][0] assert forwarded.phases[0].norms[0].norm == "N1" assert forwarded.phases[0].goals[0].description == "G1" + + # Verify history clear was triggered + assert manager._send_clear_llm_history.await_count == 1 + + +@pytest.mark.asyncio +async def test_send_clear_llm_history(mock_settings): + # Ensure the mock returns a string for the agent name (just like in your LLM tests) + mock_settings.agent_settings.llm_agent_name = "llm_agent" + + manager = BDIProgramManager(name="program_manager_test") + manager.send = AsyncMock() + + await manager._send_clear_llm_history() + + assert manager.send.await_count == 1 + msg: InternalMessage = manager.send.await_args[0][0] + + # Verify the content and recipient + assert msg.body == "clear_history" + assert msg.to == "llm_agent" diff --git 
a/test/unit/agents/llm/test_llm_agent.py b/test/unit/agents/llm/test_llm_agent.py index 5e84d8d..ef8a3bf 100644 --- a/test/unit/agents/llm/test_llm_agent.py +++ b/test/unit/agents/llm/test_llm_agent.py @@ -265,3 +265,23 @@ async def test_stream_query_llm_skips_non_data_lines(mock_httpx_client, mock_set # Only the valid 'data:' line should yield content assert tokens == ["Hi"] + + +@pytest.mark.asyncio +async def test_clear_history_command(mock_settings): + """Test that the 'clear_history' message clears the agent's memory.""" + # setup LLM to have some history + mock_settings.agent_settings.bdi_program_manager_name = "bdi_program_manager_agent" + agent = LLMAgent("llm_agent") + agent.history = [ + {"role": "user", "content": "Old conversation context"}, + {"role": "assistant", "content": "Old response"}, + ] + assert len(agent.history) == 2 + msg = InternalMessage( + to="llm_agent", + sender=mock_settings.agent_settings.bdi_program_manager_name, + body="clear_history", + ) + await agent.handle_message(msg) + assert len(agent.history) == 0 diff --git a/test/unit/agents/perception/vad_agent/test_vad_streaming.py b/test/unit/agents/perception/vad_agent/test_vad_streaming.py index 4440cae..166919f 100644 --- a/test/unit/agents/perception/vad_agent/test_vad_streaming.py +++ b/test/unit/agents/perception/vad_agent/test_vad_streaming.py @@ -7,6 +7,15 @@ import zmq from control_backend.agents.perception.vad_agent import VADAgent +# We don't want to use real ZMQ in unit tests, for example because it can give errors when sockets +# aren't closed properly. 
+@pytest.fixture(autouse=True) +def mock_zmq(): + with patch("zmq.asyncio.Context") as mock: + mock.instance.return_value = MagicMock() + yield mock + + @pytest.fixture def audio_out_socket(): return AsyncMock() @@ -140,12 +149,10 @@ async def test_vad_model_load_failure_stops_agent(vad_agent): # Patch stop to an AsyncMock so we can check it was awaited vad_agent.stop = AsyncMock() - result = await vad_agent.setup() + await vad_agent.setup() # Assert stop was called vad_agent.stop.assert_awaited_once() - # Assert setup returned None - assert result is None @pytest.mark.asyncio @@ -155,7 +162,7 @@ async def test_audio_out_bind_failure_sets_none_and_logs(vad_agent, caplog): audio_out_socket is set to None, None is returned, and an error is logged. """ mock_socket = MagicMock() - mock_socket.bind_to_random_port.side_effect = zmq.ZMQBindError() + mock_socket.bind.side_effect = zmq.ZMQBindError() with patch("control_backend.agents.perception.vad_agent.azmq.Context.instance") as mock_ctx: mock_ctx.return_value.socket.return_value = mock_socket