refactor: agents inherit logger from BaseAgent

Created a class `BaseAgent`, from which all agents inherit. They get
assigned a logger with a nice name (something like
`control_backend.agents.AgentName`).

The BDI core takes care of its own logger, as bdi is still a module.

ref: N25B-241
This commit is contained in:
2025-11-04 20:48:55 +01:00
parent d43cb9394a
commit a98018ddda
15 changed files with 174 additions and 159 deletions

View File

@@ -1,4 +1,17 @@
from .base import BaseAgent
from .belief_collector.belief_collector import BeliefCollectorAgent from .belief_collector.belief_collector import BeliefCollectorAgent
from .llm.llm import LLMAgent from .llm.llm import LLMAgent
from .ri_command_agent import RICommandAgent
from .ri_communication_agent import RICommunicationAgent from .ri_communication_agent import RICommunicationAgent
from .transcription.transcription_agent import TranscriptionAgent
from .vad_agent import VADAgent from .vad_agent import VADAgent
__all__ = [
"BaseAgent",
"BeliefCollectorAgent",
"LLMAgent",
"RICommandAgent",
"RICommunicationAgent",
"TranscriptionAgent",
"VADAgent",
]

View File

@@ -0,0 +1,18 @@
import logging
from spade.agent import Agent
class BaseAgent(Agent):
"""
Base agent class for our agents to inherit from.
This ensures that all agents have a logger.
"""
logger: logging.Logger
# Whenever a subclass is initiated, give it the correct logger
def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
cls.logger = logging.getLogger(__package__).getChild(cls.__name__)

View File

@@ -5,9 +5,10 @@ from spade.behaviour import OneShotBehaviour
from spade.message import Message from spade.message import Message
from spade_bdi.bdi import BDIAgent from spade_bdi.bdi import BDIAgent
from control_backend.core.config import settings
from .behaviours.belief_setter import BeliefSetterBehaviour from .behaviours.belief_setter import BeliefSetterBehaviour
from .behaviours.receive_llm_resp_behaviour import ReceiveLLMResponseBehaviour from .behaviours.receive_llm_resp_behaviour import ReceiveLLMResponseBehaviour
from control_backend.core.config import settings
class BDICoreAgent(BDIAgent): class BDICoreAgent(BDIAgent):
@@ -18,7 +19,7 @@ class BDICoreAgent(BDIAgent):
It has the BeliefSetter behaviour and can aks and recieve requests from the LLM agent. It has the BeliefSetter behaviour and can aks and recieve requests from the LLM agent.
""" """
logger = logging.getLogger("bdi_core_agent") logger = logging.getLogger(__package__).getChild(__name__)
async def setup(self) -> None: async def setup(self) -> None:
""" """
@@ -56,8 +57,8 @@ class BDICoreAgent(BDIAgent):
class SendBehaviour(OneShotBehaviour): class SendBehaviour(OneShotBehaviour):
async def run(self) -> None: async def run(self) -> None:
msg = Message( msg = Message(
to= settings.agent_settings.llm_agent_name + '@' + settings.agent_settings.host, to=settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
body= text body=text,
) )
await self.send(msg) await self.send(msg)

View File

@@ -1,5 +1,4 @@
import json import json
import logging
from spade.agent import Message from spade.agent import Message
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
@@ -15,12 +14,11 @@ class BeliefSetterBehaviour(CyclicBehaviour):
""" """
agent: BDIAgent agent: BDIAgent
logger = logging.getLogger(__name__)
async def run(self): async def run(self):
"""Polls for messages and processes them.""" """Polls for messages and processes them."""
msg = await self.receive() msg = await self.receive()
self.logger.debug( self.agent.logger.debug(
"Received message from %s with thread '%s' and body: %s", "Received message from %s with thread '%s' and body: %s",
msg.sender, msg.sender,
msg.thread, msg.thread,
@@ -28,23 +26,24 @@ class BeliefSetterBehaviour(CyclicBehaviour):
) )
self._process_message(msg) self._process_message(msg)
def _process_message(self, message: Message): def _process_message(self, message: Message):
"""Routes the message to the correct processing function based on the sender.""" """Routes the message to the correct processing function based on the sender."""
sender = message.sender.node # removes host from jid and converts to str sender = message.sender.node # removes host from jid and converts to str
self.logger.debug("Processing message from sender: %s", sender) self.agent.logger.debug("Processing message from sender: %s", sender)
match sender: match sender:
case settings.agent_settings.belief_collector_agent_name: case settings.agent_settings.belief_collector_agent_name:
self.logger.debug("Message is from the belief collector agent. Processing as belief message.") self.agent.logger.debug(
"Message is from the belief collector agent. Processing as belief message."
)
self._process_belief_message(message) self._process_belief_message(message)
case _: case _:
self.logger.debug("Not the belief agent, discarding message") self.agent.logger.debug("Not the belief agent, discarding message")
pass pass
def _process_belief_message(self, message: Message): def _process_belief_message(self, message: Message):
if not message.body: if not message.body:
self.logger.debug("Ignoring message with empty body from %s", message.sender.node) self.agent.logger.debug("Ignoring message with empty body from %s", message.sender.node)
return return
match message.thread: match message.thread:
@@ -53,10 +52,10 @@ class BeliefSetterBehaviour(CyclicBehaviour):
beliefs: dict[str, list[str]] = json.loads(message.body) beliefs: dict[str, list[str]] = json.loads(message.body)
self._set_beliefs(beliefs) self._set_beliefs(beliefs)
except json.JSONDecodeError: except json.JSONDecodeError:
self.logger.error( self.agent.logger.error(
"Could not decode beliefs from JSON. Message body: '%s'", "Could not decode beliefs from JSON. Message body: '%s'",
message.body, message.body,
exc_info=True exc_info=True,
) )
case _: case _:
pass pass
@@ -64,21 +63,23 @@ class BeliefSetterBehaviour(CyclicBehaviour):
def _set_beliefs(self, beliefs: dict[str, list[str]]): def _set_beliefs(self, beliefs: dict[str, list[str]]):
"""Removes previous values for beliefs and updates them with the provided values.""" """Removes previous values for beliefs and updates them with the provided values."""
if self.agent.bdi is None: if self.agent.bdi is None:
self.logger.warning("Cannot set beliefs; agent's BDI is not yet initialized.") self.agent.logger.warning("Cannot set beliefs; agent's BDI is not yet initialized.")
return return
if not beliefs: if not beliefs:
self.logger.debug("Received an empty set of beliefs. No beliefs were updated.") self.agent.logger.debug("Received an empty set of beliefs. No beliefs were updated.")
return return
# Set new beliefs (outdated beliefs are automatically removed) # Set new beliefs (outdated beliefs are automatically removed)
for belief, arguments in beliefs.items(): for belief, arguments in beliefs.items():
self.logger.debug("Setting belief %s with arguments %s", belief, arguments) self.agent.logger.debug("Setting belief %s with arguments %s", belief, arguments)
self.agent.bdi.set_belief(belief, *arguments) self.agent.bdi.set_belief(belief, *arguments)
# Special case: if there's a new user message, flag that we haven't responded yet # Special case: if there's a new user message, flag that we haven't responded yet
if belief == "user_said": if belief == "user_said":
self.agent.bdi.set_belief("new_message") self.agent.bdi.set_belief("new_message")
self.logger.debug("Detected 'user_said' belief, also setting 'new_message' belief.") self.agent.logger.debug(
"Detected 'user_said' belief, also setting 'new_message' belief."
)
self.logger.info("Successfully updated %d beliefs.", len(beliefs)) self.agent.logger.info("Successfully updated %d beliefs.", len(beliefs))

View File

@@ -1,5 +1,3 @@
import logging
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
from control_backend.core.config import settings from control_backend.core.config import settings
@@ -9,7 +7,7 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
""" """
Adds behavior to receive responses from the LLM Agent. Adds behavior to receive responses from the LLM Agent.
""" """
logger = logging.getLogger(__name__)
async def run(self): async def run(self):
msg = await self.receive() msg = await self.receive()
@@ -17,8 +15,8 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
match sender: match sender:
case settings.agent_settings.llm_agent_name: case settings.agent_settings.llm_agent_name:
content = msg.body content = msg.body
self.logger.info("Received LLM response: %s", content) self.agent.logger.info("Received LLM response: %s", content)
#Here the BDI can pass the message back as a response # Here the BDI can pass the message back as a response
case _: case _:
self.logger.debug("Discarding message from %s", sender) self.agent.logger.debug("Discarding message from %s", sender)
pass pass

View File

@@ -1,6 +1,4 @@
import asyncio
import json import json
import logging
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
from spade.message import Message from spade.message import Message
@@ -9,8 +7,6 @@ from control_backend.core.config import settings
class BeliefFromText(CyclicBehaviour): class BeliefFromText(CyclicBehaviour):
logger = logging.getLogger("Belief From Text")
# TODO: LLM prompt nog hardcoded # TODO: LLM prompt nog hardcoded
llm_instruction_prompt = """ llm_instruction_prompt = """
You are an information extraction assistent for a BDI agent. Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules: You are an information extraction assistent for a BDI agent. Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules:
@@ -25,16 +21,13 @@ class BeliefFromText(CyclicBehaviour):
""" """
# on_start agent receives message containing the beliefs to look out for and sets up the LLM with instruction prompt # on_start agent receives message containing the beliefs to look out for and sets up the LLM with instruction prompt
#async def on_start(self): # async def on_start(self):
# msg = await self.receive(timeout=0.1) # msg = await self.receive(timeout=0.1)
# self.beliefs = dict uit message # self.beliefs = dict uit message
# send instruction prompt to LLM # send instruction prompt to LLM
beliefs: dict[str, list[str]] beliefs: dict[str, list[str]]
beliefs = { beliefs = {"mood": ["X"], "car": ["Y"]}
"mood": ["X"],
"car": ["Y"]
}
async def run(self): async def run(self):
msg = await self.receive() msg = await self.receive()
@@ -56,8 +49,8 @@ class BeliefFromText(CyclicBehaviour):
prompt = text_prompt + beliefs_prompt prompt = text_prompt + beliefs_prompt
self.logger.info(prompt) self.logger.info(prompt)
#prompt_msg = Message(to="LLMAgent@whatever") # prompt_msg = Message(to="LLMAgent@whatever")
#response = self.send(prompt_msg) # response = self.send(prompt_msg)
# Mock response; response is beliefs in JSON format, it parses do dict[str,list[list[str]]] # Mock response; response is beliefs in JSON format, it parses do dict[str,list[list[str]]]
response = '{"mood": [["happy"]]}' response = '{"mood": [["happy"]]}'
@@ -65,15 +58,16 @@ class BeliefFromText(CyclicBehaviour):
try: try:
json.loads(response) json.loads(response)
belief_message = Message( belief_message = Message(
to=settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
body=response) body=response,
)
belief_message.thread = "beliefs" belief_message.thread = "beliefs"
await self.send(belief_message) await self.send(belief_message)
self.logger.info("Sent beliefs to BDI.") self.agent.logger.info("Sent beliefs to BDI.")
except json.JSONDecodeError: except json.JSONDecodeError:
# Parsing failed, so the response is in the wrong format, log warning # Parsing failed, so the response is in the wrong format, log warning
self.logger.warning("Received LLM response in incorrect format.") self.agent.logger.warning("Received LLM response in incorrect format.")
async def _process_transcription_demo(self, txt: str): async def _process_transcription_demo(self, txt: str):
""" """
@@ -83,9 +77,12 @@ class BeliefFromText(CyclicBehaviour):
""" """
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"} belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
payload = json.dumps(belief) payload = json.dumps(belief)
belief_msg = Message(to=settings.agent_settings.belief_collector_agent_name belief_msg = Message(
+ '@' + settings.agent_settings.host, to=settings.agent_settings.belief_collector_agent_name
body=payload) + "@"
+ settings.agent_settings.host,
body=payload,
)
belief_msg.thread = "beliefs" belief_msg.thread = "beliefs"
await self.send(belief_msg) await self.send(belief_msg)

View File

@@ -1,8 +1,8 @@
from spade.agent import Agent from control_backend.agents.base import BaseAgent
from .behaviours.text_belief_extractor import BeliefFromText from .behaviours.text_belief_extractor import BeliefFromText
class TBeliefExtractorAgent(Agent): class TBeliefExtractorAgent(BaseAgent):
async def setup(self): async def setup(self):
self.add_behaviour(BeliefFromText()) self.add_behaviour(BeliefFromText())

View File

@@ -1,23 +1,22 @@
import json import json
import logging
from json import JSONDecodeError from json import JSONDecodeError
from spade.behaviour import CyclicBehaviour
from spade.agent import Message from spade.agent import Message
from spade.behaviour import CyclicBehaviour
from control_backend.core.config import settings from control_backend.core.config import settings
class ContinuousBeliefCollector(CyclicBehaviour): class ContinuousBeliefCollector(CyclicBehaviour):
""" """
Continuously collects beliefs/emotions from extractor agents: Continuously collects beliefs/emotions from extractor agents:
Then we send a unified belief packet to the BDI agent. Then we send a unified belief packet to the BDI agent.
""" """
logger = logging.getLogger(__name__)
async def run(self): async def run(self):
msg = await self.receive() msg = await self.receive()
await self._process_message(msg) await self._process_message(msg)
async def _process_message(self, msg: Message): async def _process_message(self, msg: Message):
sender_node = msg.sender.node sender_node = msg.sender.node
@@ -25,9 +24,8 @@ class ContinuousBeliefCollector(CyclicBehaviour):
try: try:
payload = json.loads(msg.body) payload = json.loads(msg.body)
except JSONDecodeError as e: except JSONDecodeError as e:
self.logger.warning( self.agent.logger.warning(
"Failed to parse JSON from %s. Body=%r Error=%s", "Failed to parse JSON from %s. Body=%r Error=%s", sender_node, msg.body, e
sender_node, msg.body, e
) )
return return
@@ -35,19 +33,19 @@ class ContinuousBeliefCollector(CyclicBehaviour):
# Prefer explicit 'type' field # Prefer explicit 'type' field
if msg_type == "belief_extraction_text" or sender_node == "belief_text_agent_mock": if msg_type == "belief_extraction_text" or sender_node == "belief_text_agent_mock":
self.logger.debug("Message routed to _handle_belief_text (sender=%s)", sender_node) self.agent.logger.debug(
"Message routed to _handle_belief_text (sender=%s)", sender_node
)
await self._handle_belief_text(payload, sender_node) await self._handle_belief_text(payload, sender_node)
#This is not implemented yet, but we keep the structure for future use # This is not implemented yet, but we keep the structure for future use
elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock": elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock":
self.logger.debug("Message routed to _handle_emo_text (sender=%s)", sender_node) self.agent.logger.debug("Message routed to _handle_emo_text (sender=%s)", sender_node)
await self._handle_emo_text(payload, sender_node) await self._handle_emo_text(payload, sender_node)
else: else:
self.logger.warning( self.agent.logger.warning(
"Unrecognized message (sender=%s, type=%r). Ignoring.", "Unrecognized message (sender=%s, type=%r). Ignoring.", sender_node, msg_type
sender_node, msg_type
) )
async def _handle_belief_text(self, payload: dict, origin: str): async def _handle_belief_text(self, payload: dict, origin: str):
""" """
Expected payload: Expected payload:
@@ -61,23 +59,20 @@ class ContinuousBeliefCollector(CyclicBehaviour):
beliefs = payload.get("beliefs", {}) beliefs = payload.get("beliefs", {})
if not beliefs: if not beliefs:
self.logger.debug("Received empty beliefs set.") self.agent.logger.debug("Received empty beliefs set.")
return return
self.logger.debug("Forwarding %d beliefs.", len(beliefs)) self.agent.logger.debug("Forwarding %d beliefs.", len(beliefs))
for belief_name, belief_list in beliefs.items(): for belief_name, belief_list in beliefs.items():
for belief in belief_list: for belief in belief_list:
self.logger.debug(" - %s %s", belief_name,str(belief)) self.agent.logger.debug(" - %s %s", belief_name, str(belief))
await self._send_beliefs_to_bdi(beliefs, origin=origin) await self._send_beliefs_to_bdi(beliefs, origin=origin)
async def _handle_emo_text(self, payload: dict, origin: str): async def _handle_emo_text(self, payload: dict, origin: str):
"""TODO: implement (after we have emotional recogntion)""" """TODO: implement (after we have emotional recogntion)"""
pass pass
async def _send_beliefs_to_bdi(self, beliefs: list[str], origin: str | None = None): async def _send_beliefs_to_bdi(self, beliefs: list[str], origin: str | None = None):
""" """
Sends a unified belief packet to the BDI agent. Sends a unified belief packet to the BDI agent.
@@ -90,6 +85,5 @@ class ContinuousBeliefCollector(CyclicBehaviour):
msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs") msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs")
msg.body = json.dumps(beliefs) msg.body = json.dumps(beliefs)
await self.send(msg) await self.send(msg)
self.logger.info("Sent %d belief(s) to BDI core.", len(beliefs)) self.agent.logger.info("Sent %d belief(s) to BDI core.", len(beliefs))

View File

@@ -1,13 +1,11 @@
import logging from control_backend.agents.base import BaseAgent
from spade.agent import Agent
from .behaviours.continuous_collect import ContinuousBeliefCollector from .behaviours.continuous_collect import ContinuousBeliefCollector
logger = logging.getLogger(__name__)
class BeliefCollectorAgent(Agent): class BeliefCollectorAgent(BaseAgent):
async def setup(self): async def setup(self):
logger.info("BeliefCollectorAgent starting (%s)", self.jid) self.logger.info("BeliefCollectorAgent starting (%s)", self.jid)
# Attach the continuous collector behaviour (listens and forwards to BDI) # Attach the continuous collector behaviour (listens and forwards to BDI)
self.add_behaviour(ContinuousBeliefCollector()) self.add_behaviour(ContinuousBeliefCollector())
logger.info("BeliefCollectorAgent ready.") self.logger.info("BeliefCollectorAgent ready.")

View File

@@ -1,29 +1,22 @@
"""
LLM Agent module for routing text queries from the BDI Core Agent to a local LLM
service and returning its responses back to the BDI Core Agent.
"""
import logging
from typing import Any from typing import Any
import httpx import httpx
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
from spade.message import Message from spade.message import Message
from .llm_instructions import LLMInstructions from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from .llm_instructions import LLMInstructions
class LLMAgent(Agent):
class LLMAgent(BaseAgent):
""" """
Agent responsible for processing user text input and querying a locally Agent responsible for processing user text input and querying a locally
hosted LLM for text generation. Receives messages from the BDI Core Agent hosted LLM for text generation. Receives messages from the BDI Core Agent
and responds with processed LLM output. and responds with processed LLM output.
""" """
logger = logging.getLogger("llm_agent")
class ReceiveMessageBehaviour(CyclicBehaviour): class ReceiveMessageBehaviour(CyclicBehaviour):
""" """
Cyclic behaviour to continuously listen for incoming messages from Cyclic behaviour to continuously listen for incoming messages from
@@ -63,8 +56,8 @@ class LLMAgent(Agent):
Sends a response message back to the BDI Core Agent. Sends a response message back to the BDI Core Agent.
""" """
reply = Message( reply = Message(
to=settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host, to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
body=msg body=msg,
) )
await self.send(reply) await self.send(reply)
self.agent.logger.info("Reply sent to BDI Core Agent") self.agent.logger.info("Reply sent to BDI Core Agent")
@@ -88,25 +81,21 @@ class LLMAgent(Agent):
json={ json={
"model": settings.llm_settings.local_llm_model, "model": settings.llm_settings.local_llm_model,
"messages": [ "messages": [
{ {"role": "developer", "content": developer_instruction},
"role": "developer", {"role": "user", "content": prompt},
"content": developer_instruction
},
{
"role": "user",
"content": prompt
}
], ],
"temperature": 0.3 "temperature": 0.3,
}, },
) )
try: try:
response.raise_for_status() response.raise_for_status()
data: dict[str, Any] = response.json() data: dict[str, Any] = response.json()
return data.get("choices", [{}])[0].get( return (
"message", {} data.get("choices", [{}])[0]
).get("content", "No response") .get("message", {})
.get("content", "No response")
)
except httpx.HTTPError as err: except httpx.HTTPError as err:
self.agent.logger.error("HTTP error: %s", err) self.agent.logger.error("HTTP error: %s", err)
return "LLM service unavailable." return "LLM service unavailable."

View File

@@ -1,9 +1,12 @@
import json import json
from spade.agent import Agent from spade.agent import Agent
from spade.behaviour import OneShotBehaviour from spade.behaviour import OneShotBehaviour
from spade.message import Message from spade.message import Message
from control_backend.core.config import settings from control_backend.core.config import settings
class BeliefTextAgent(Agent): class BeliefTextAgent(Agent):
class SendOnceBehaviourBlfText(OneShotBehaviour): class SendOnceBehaviourBlfText(OneShotBehaviour):
async def run(self): async def run(self):
@@ -12,7 +15,15 @@ class BeliefTextAgent(Agent):
# Send multiple beliefs in one JSON payload # Send multiple beliefs in one JSON payload
payload = { payload = {
"type": "belief_extraction_text", "type": "belief_extraction_text",
"beliefs": {"user_said": ["hello test","Can you help me?","stop talking to me","No","Pepper do a dance"]} "beliefs": {
"user_said": [
"hello test",
"Can you help me?",
"stop talking to me",
"No",
"Pepper do a dance",
]
},
} }
msg = Message(to=to_jid) msg = Message(to=to_jid)

View File

@@ -1,17 +1,15 @@
import json import json
import logging
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
import zmq
import zmq
from spade.behaviour import CyclicBehaviour
from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context from control_backend.core.zmq_context import context
from control_backend.schemas.ri_message import SpeechCommand from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__)
class RICommandAgent(BaseAgent):
class RICommandAgent(Agent):
subsocket: zmq.Socket subsocket: zmq.Socket
pubsocket: zmq.Socket pubsocket: zmq.Socket
address = "" address = ""
@@ -47,13 +45,13 @@ class RICommandAgent(Agent):
# Send to the robot. # Send to the robot.
await self.agent.pubsocket.send_json(message.model_dump()) await self.agent.pubsocket.send_json(message.model_dump())
except Exception as e: except Exception as e:
logger.error("Error processing message: %s", e) self.logger.error("Error processing message: %s", e)
async def setup(self): async def setup(self):
""" """
Setup the command agent Setup the command agent
""" """
logger.info("Setting up %s", self.jid) self.logger.info("Setting up %s", self.jid)
# To the robot # To the robot
self.pubsocket = context.socket(zmq.PUB) self.pubsocket = context.socket(zmq.PUB)
@@ -71,4 +69,4 @@ class RICommandAgent(Agent):
commands_behaviour = self.SendCommandsBehaviour() commands_behaviour = self.SendCommandsBehaviour()
self.add_behaviour(commands_behaviour) self.add_behaviour(commands_behaviour)
logger.info("Finished setting up %s", self.jid) self.logger.info("Finished setting up %s", self.jid)

View File

@@ -1,17 +1,16 @@
import asyncio import asyncio
import logging
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
import zmq
import zmq
from spade.behaviour import CyclicBehaviour
from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context from control_backend.core.zmq_context import context
from .ri_command_agent import RICommandAgent from .ri_command_agent import RICommandAgent
logger = logging.getLogger(__name__)
class RICommunicationAgent(BaseAgent):
class RICommunicationAgent(Agent):
req_socket: zmq.Socket req_socket: zmq.Socket
_address = "" _address = ""
_bind = True _bind = True
@@ -45,13 +44,13 @@ class RICommunicationAgent(Agent):
message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0) message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0)
# We didnt get a reply :( # We didnt get a reply :(
except asyncio.TimeoutError as e: except TimeoutError:
logger.info("No ping retrieved in 3 seconds, killing myself.") self.agent.logger.info("No ping retrieved in 3 seconds, killing myself.")
self.kill() self.kill()
logger.debug('Received message "%s"', message) self.agent.logger.debug('Received message "%s"', message)
if "endpoint" not in message: if "endpoint" not in message:
logger.error("No received endpoint in message, excepted ping endpoint.") self.agent.logger.error("No received endpoint in message, excepted ping endpoint.")
return return
# See what endpoint we received # See what endpoint we received
@@ -59,7 +58,7 @@ class RICommunicationAgent(Agent):
case "ping": case "ping":
await asyncio.sleep(1) await asyncio.sleep(1)
case _: case _:
logger.info( self.agent.logger.info(
"Received message with topic different than ping, while ping expected." "Received message with topic different than ping, while ping expected."
) )
@@ -67,7 +66,7 @@ class RICommunicationAgent(Agent):
""" """
Try to setup the communication agent, we have 5 retries in case we dont have a response yet. Try to setup the communication agent, we have 5 retries in case we dont have a response yet.
""" """
logger.info("Setting up %s", self.jid) self.logger.info("Setting up %s", self.jid)
retries = 0 retries = 0
# Let's try a certain amount of times before failing connection # Let's try a certain amount of times before failing connection
@@ -86,8 +85,8 @@ class RICommunicationAgent(Agent):
try: try:
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0) received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
except asyncio.TimeoutError: except TimeoutError:
logger.warning( self.logger.warning(
"No connection established in 20 seconds (attempt %d/%d)", "No connection established in 20 seconds (attempt %d/%d)",
retries + 1, retries + 1,
max_retries, max_retries,
@@ -96,7 +95,7 @@ class RICommunicationAgent(Agent):
continue continue
except Exception as e: except Exception as e:
logger.error("Unexpected error during negotiation: %s", e) self.logger.error("Unexpected error during negotiation: %s", e)
retries += 1 retries += 1
continue continue
@@ -104,7 +103,7 @@ class RICommunicationAgent(Agent):
endpoint = received_message.get("endpoint") endpoint = received_message.get("endpoint")
if endpoint != "negotiate/ports": if endpoint != "negotiate/ports":
# TODO: Should this send a message back? # TODO: Should this send a message back?
logger.error( self.logger.error(
"Invalid endpoint '%s' received (attempt %d/%d)", "Invalid endpoint '%s' received (attempt %d/%d)",
endpoint, endpoint,
retries + 1, retries + 1,
@@ -143,10 +142,10 @@ class RICommunicationAgent(Agent):
) )
await ri_commands_agent.start() await ri_commands_agent.start()
case _: case _:
logger.warning("Unhandled negotiation id: %s", id) self.logger.warning("Unhandled negotiation id: %s", id)
except Exception as e: except Exception as e:
logger.error("Error unpacking negotiation data: %s", e) self.logger.error("Error unpacking negotiation data: %s", e)
retries += 1 retries += 1
continue continue
@@ -154,10 +153,10 @@ class RICommunicationAgent(Agent):
break break
else: else:
logger.error("Failed to set up RICommunicationAgent after %d retries", max_retries) self.logger.error("Failed to set up RICommunicationAgent after %d retries", max_retries)
return return
# Set up ping behaviour # Set up ping behaviour
listen_behaviour = self.ListenBehaviour() listen_behaviour = self.ListenBehaviour()
self.add_behaviour(listen_behaviour) self.add_behaviour(listen_behaviour)
logger.info("Finished setting up %s", self.jid) self.logger.info("Finished setting up %s", self.jid)

View File

@@ -1,21 +1,19 @@
import asyncio import asyncio
import logging
import numpy as np import numpy as np
import zmq import zmq
import zmq.asyncio as azmq import zmq.asyncio as azmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
from spade.message import Message from spade.message import Message
from .speech_recognizer import SpeechRecognizer from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context from control_backend.core.zmq_context import context as zmq_context
logger = logging.getLogger(__name__) from .speech_recognizer import SpeechRecognizer
class TranscriptionAgent(Agent): class TranscriptionAgent(BaseAgent):
""" """
An agent which listens to audio fragments with voice, transcribes them, and sends the An agent which listens to audio fragments with voice, transcribes them, and sends the
transcription to other agents. transcription to other agents.
@@ -47,7 +45,8 @@ class TranscriptionAgent(Agent):
"""Share a transcription to the other agents that depend on it.""" """Share a transcription to the other agents that depend on it."""
receiver_jids = [ receiver_jids = [
settings.agent_settings.text_belief_extractor_agent_name settings.agent_settings.text_belief_extractor_agent_name
+ '@' + settings.agent_settings.host, + "@"
+ settings.agent_settings.host,
] # Set message receivers here ] # Set message receivers here
for receiver_jid in receiver_jids: for receiver_jid in receiver_jids:
@@ -58,7 +57,7 @@ class TranscriptionAgent(Agent):
audio = await self.audio_in_socket.recv() audio = await self.audio_in_socket.recv()
audio = np.frombuffer(audio, dtype=np.float32) audio = np.frombuffer(audio, dtype=np.float32)
speech = await self._transcribe(audio) speech = await self._transcribe(audio)
logger.info("Transcribed speech: %s", speech) self.agent.logger.info("Transcribed speech: %s", speech)
await self._share_transcription(speech) await self._share_transcription(speech)
@@ -73,7 +72,7 @@ class TranscriptionAgent(Agent):
self.audio_in_socket.connect(self.audio_in_address) self.audio_in_socket.connect(self.audio_in_address)
async def setup(self): async def setup(self):
logger.info("Setting up %s", self.jid) self.logger.info("Setting up %s", self.jid)
self._connect_audio_in_socket() self._connect_audio_in_socket()
@@ -81,4 +80,4 @@ class TranscriptionAgent(Agent):
transcribing.warmup() transcribing.warmup()
self.add_behaviour(transcribing) self.add_behaviour(transcribing)
logger.info("Finished setting up %s", self.jid) self.logger.info("Finished setting up %s", self.jid)

View File

@@ -1,17 +1,14 @@
import logging
import numpy as np import numpy as np
import torch import torch
import zmq import zmq
import zmq.asyncio as azmq import zmq.asyncio as azmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour from spade.behaviour import CyclicBehaviour
from .transcription.transcription_agent import TranscriptionAgent from control_backend.agents import BaseAgent
from control_backend.core.config import settings from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context from control_backend.core.zmq_context import context as zmq_context
logger = logging.getLogger(__name__) from .transcription.transcription_agent import TranscriptionAgent
class SocketPoller[T]: class SocketPoller[T]:
@@ -60,7 +57,9 @@ class Streaming(CyclicBehaviour):
data = await self.audio_in_poller.poll() data = await self.audio_in_poller.poll()
if data is None: if data is None:
if len(self.audio_buffer) > 0: if len(self.audio_buffer) > 0:
logger.debug("No audio data received. Discarding buffer until new data arrives.") self.agent.logger.debug(
"No audio data received. Discarding buffer until new data arrives."
)
self.audio_buffer = np.array([], dtype=np.float32) self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = 100 self.i_since_speech = 100
return return
@@ -71,7 +70,7 @@ class Streaming(CyclicBehaviour):
if prob > 0.5: if prob > 0.5:
if self.i_since_speech > 3: if self.i_since_speech > 3:
logger.debug("Speech started.") self.agent.logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk) self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0 self.i_since_speech = 0
return return
@@ -84,7 +83,7 @@ class Streaming(CyclicBehaviour):
# Speech probably ended. Make sure we have a usable amount of data. # Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3 * len(chunk): if len(self.audio_buffer) >= 3 * len(chunk):
logger.debug("Speech ended.") self.agent.logger.debug("Speech ended.")
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes()) await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
# At this point, we know that the speech has ended. # At this point, we know that the speech has ended.
@@ -92,7 +91,7 @@ class Streaming(CyclicBehaviour):
self.audio_buffer = chunk self.audio_buffer = chunk
class VADAgent(Agent): class VADAgent(BaseAgent):
""" """
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
fragments with detected speech to other agents over ZeroMQ. fragments with detected speech to other agents over ZeroMQ.
@@ -135,12 +134,12 @@ class VADAgent(Agent):
self.audio_out_socket = zmq_context.socket(zmq.PUB) self.audio_out_socket = zmq_context.socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100) return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
except zmq.ZMQBindError: except zmq.ZMQBindError:
logger.error("Failed to bind an audio output socket after 100 tries.") self.logger.error("Failed to bind an audio output socket after 100 tries.")
self.audio_out_socket = None self.audio_out_socket = None
return None return None
async def setup(self): async def setup(self):
logger.info("Setting up %s", self.jid) self.logger.info("Setting up %s", self.jid)
self._connect_audio_in_socket() self._connect_audio_in_socket()
@@ -157,4 +156,4 @@ class VADAgent(Agent):
transcriber = TranscriptionAgent(audio_out_address) transcriber = TranscriptionAgent(audio_out_address)
await transcriber.start() await transcriber.start()
logger.info("Finished setting up %s", self.jid) self.logger.info("Finished setting up %s", self.jid)