chore: apply suggestion
Changed `add_background_task` to `add_behavior` and added extra docs.
This commit is contained in:
@@ -45,7 +45,7 @@ class RobotSpeechAgent(BaseAgent):
|
|||||||
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
|
self.subsocket.connect(settings.zmq_settings.internal_sub_address)
|
||||||
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
|
self.subsocket.setsockopt(zmq.SUBSCRIBE, b"command")
|
||||||
|
|
||||||
await self.add_background_task(self._zmq_command_loop())
|
await self.add_behavior(self._zmq_command_loop())
|
||||||
|
|
||||||
self.logger.info("Finished setting up %s", self.name)
|
self.logger.info("Finished setting up %s", self.name)
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class BDICoreAgent(BaseAgent):
|
|||||||
await self._load_asl()
|
await self._load_asl()
|
||||||
|
|
||||||
# Start the BDI cycle loop
|
# Start the BDI cycle loop
|
||||||
await self.add_background_task(self._bdi_loop())
|
await self.add_behavior(self._bdi_loop())
|
||||||
self.logger.debug("Setup complete.")
|
self.logger.debug("Setup complete.")
|
||||||
|
|
||||||
async def _load_asl(self):
|
async def _load_asl(self):
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class RICommunicationAgent(BaseAgent):
|
|||||||
|
|
||||||
if await self._negotiate_connection():
|
if await self._negotiate_connection():
|
||||||
self.connected = True
|
self.connected = True
|
||||||
await self.add_background_task(self._listen_loop())
|
await self.add_behavior(self._listen_loop())
|
||||||
else:
|
else:
|
||||||
self.logger.warning("Failed to negotiate connection during setup.")
|
self.logger.warning("Failed to negotiate connection during setup.")
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class TranscriptionAgent(BaseAgent):
|
|||||||
self.speech_recognizer.load_model() # Warmup
|
self.speech_recognizer.load_model() # Warmup
|
||||||
|
|
||||||
# Start background loop
|
# Start background loop
|
||||||
await self.add_background_task(self._transcribing_loop())
|
await self.add_behavior(self._transcribing_loop())
|
||||||
|
|
||||||
self.logger.info("Finished setting up %s", self.name)
|
self.logger.info("Finished setting up %s", self.name)
|
||||||
|
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ class VADAgent(BaseAgent):
|
|||||||
# Warmup/reset
|
# Warmup/reset
|
||||||
await self.reset_stream()
|
await self.reset_stream()
|
||||||
|
|
||||||
await self.add_background_task(self._streaming_loop())
|
await self.add_behavior(self._streaming_loop())
|
||||||
|
|
||||||
# Start agents dependent on the output audio fragments here
|
# Start agents dependent on the output audio fragments here
|
||||||
transcriber = TranscriptionAgent(audio_out_address)
|
transcriber = TranscriptionAgent(audio_out_address)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Coroutine
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
# Central directory to resolve agent names to instances
|
# Central directory to resolve agent names to instances
|
||||||
@@ -69,7 +70,7 @@ class BaseAgent(ABC):
|
|||||||
await self.setup()
|
await self.setup()
|
||||||
|
|
||||||
# Start processing inbox
|
# Start processing inbox
|
||||||
await self.add_background_task(self._process_inbox())
|
await self.add_behavior(self._process_inbox())
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
"""Stops the agent."""
|
"""Stops the agent."""
|
||||||
@@ -98,8 +99,13 @@ class BaseAgent(ABC):
|
|||||||
"""Override this to handle incoming messages."""
|
"""Override this to handle incoming messages."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def add_background_task(self, coro):
|
async def add_behavior(self, coro: Coroutine):
|
||||||
"""Helper to add a behavior to the agent."""
|
"""
|
||||||
|
Helper to add a behavior to the agent. To add asynchronous behavior to an agent, define
|
||||||
|
an `async` function and add it to the task list by calling :func:`add_background_task`
|
||||||
|
with it. This should happen in the :func:`setup` method of the agent. For an example, see:
|
||||||
|
:func:`~control_backend.agents.bdi.BDICoreAgent`.
|
||||||
|
"""
|
||||||
task = asyncio.create_task(coro)
|
task = asyncio.create_task(coro)
|
||||||
self._tasks.add(task)
|
self._tasks.add(task)
|
||||||
task.add_done_callback(self._tasks.discard)
|
task.add_done_callback(self._tasks.discard)
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ async def test_normal_setup(per_transcription_agent):
|
|||||||
async def swallow_background_task(coro):
|
async def swallow_background_task(coro):
|
||||||
coro.close()
|
coro.close()
|
||||||
|
|
||||||
per_vad_agent.add_background_task = swallow_background_task
|
per_vad_agent.add_behavior = swallow_background_task
|
||||||
per_vad_agent.reset_stream = AsyncMock()
|
per_vad_agent.reset_stream = AsyncMock()
|
||||||
|
|
||||||
await per_vad_agent.setup()
|
await per_vad_agent.setup()
|
||||||
@@ -110,7 +110,7 @@ async def test_out_socket_creation_failure(zmq_context):
|
|||||||
async def swallow_background_task(coro):
|
async def swallow_background_task(coro):
|
||||||
coro.close()
|
coro.close()
|
||||||
|
|
||||||
per_vad_agent.add_background_task = swallow_background_task
|
per_vad_agent.add_behavior = swallow_background_task
|
||||||
|
|
||||||
await per_vad_agent.setup()
|
await per_vad_agent.setup()
|
||||||
|
|
||||||
@@ -130,7 +130,7 @@ async def test_stop(zmq_context, per_transcription_agent):
|
|||||||
async def swallow_background_task(coro):
|
async def swallow_background_task(coro):
|
||||||
coro.close()
|
coro.close()
|
||||||
|
|
||||||
per_vad_agent.add_background_task = swallow_background_task
|
per_vad_agent.add_behavior = swallow_background_task
|
||||||
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
|
zmq_context.return_value.socket.return_value.bind_to_random_port.return_value = random.randint(
|
||||||
1000,
|
1000,
|
||||||
10000,
|
10000,
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ async def test_setup_bind(zmq_context, mocker):
|
|||||||
coro.close()
|
coro.close()
|
||||||
|
|
||||||
swallow = Swallow()
|
swallow = Swallow()
|
||||||
agent.add_background_task = swallow
|
agent.add_behavior = swallow
|
||||||
|
|
||||||
await agent.setup()
|
await agent.setup()
|
||||||
|
|
||||||
@@ -62,7 +62,7 @@ async def test_setup_connect(zmq_context, mocker):
|
|||||||
coro.close()
|
coro.close()
|
||||||
|
|
||||||
swallow = Swallow()
|
swallow = Swallow()
|
||||||
agent.add_background_task = swallow
|
agent.add_behavior = swallow
|
||||||
|
|
||||||
await agent.setup()
|
await agent.setup()
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ async def test_setup_success_connects_and_starts_robot(zmq_context):
|
|||||||
coro.close()
|
coro.close()
|
||||||
|
|
||||||
swallow = Swallow()
|
swallow = Swallow()
|
||||||
agent.add_background_task = swallow
|
agent.add_behavior = swallow
|
||||||
|
|
||||||
await agent.setup()
|
await agent.setup()
|
||||||
|
|
||||||
@@ -85,7 +85,7 @@ async def test_setup_binds_when_requested(zmq_context):
|
|||||||
coro.close()
|
coro.close()
|
||||||
|
|
||||||
swallow = Swallow()
|
swallow = Swallow()
|
||||||
agent.add_background_task = swallow
|
agent.add_behavior = swallow
|
||||||
|
|
||||||
with patch(speech_agent_path(), autospec=True) as MockRobot:
|
with patch(speech_agent_path(), autospec=True) as MockRobot:
|
||||||
MockRobot.return_value.start = AsyncMock()
|
MockRobot.return_value.start = AsyncMock()
|
||||||
@@ -213,7 +213,7 @@ async def test_setup_warns_on_failed_negotiate(zmq_context, mocker):
|
|||||||
async def swallow(coro):
|
async def swallow(coro):
|
||||||
coro.close()
|
coro.close()
|
||||||
|
|
||||||
agent.add_background_task = swallow
|
agent.add_behavior = swallow
|
||||||
agent._negotiate_connection = AsyncMock(return_value=False)
|
agent._negotiate_connection = AsyncMock(return_value=False)
|
||||||
|
|
||||||
await agent.setup()
|
await agent.setup()
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ async def test_transcription_agent_flow(mock_zmq_context):
|
|||||||
agent.send = AsyncMock()
|
agent.send = AsyncMock()
|
||||||
|
|
||||||
agent._running = True
|
agent._running = True
|
||||||
agent.add_background_task = AsyncMock()
|
agent.add_behavior = AsyncMock()
|
||||||
|
|
||||||
await agent.setup()
|
await agent.setup()
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ async def test_agent_lifecycle():
|
|||||||
async def dummy_task():
|
async def dummy_task():
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
await agent.add_background_task(dummy_task())
|
await agent.add_behavior(dummy_task())
|
||||||
assert len(agent._tasks) > 0
|
assert len(agent._tasks) > 0
|
||||||
|
|
||||||
# Wait for task to finish
|
# Wait for task to finish
|
||||||
|
|||||||
Reference in New Issue
Block a user