refactor: agents inherit logger from BaseAgent

Created a class `BaseAgent`, from which all agents inherit. They get
assigned a logger with a nice name (something like
`control_backend.agents.AgentName`).

The BDI core takes care of its own logger, as bdi is still a module.

ref: N25B-241
This commit is contained in:
2025-11-04 20:48:55 +01:00
parent d43cb9394a
commit a98018ddda
15 changed files with 174 additions and 159 deletions

View File

@@ -1,4 +1,17 @@
from .base import BaseAgent
from .belief_collector.belief_collector import BeliefCollectorAgent
from .llm.llm import LLMAgent
from .ri_command_agent import RICommandAgent
from .ri_communication_agent import RICommunicationAgent
from .vad_agent import VADAgent
from .transcription.transcription_agent import TranscriptionAgent
from .vad_agent import VADAgent
__all__ = [
"BaseAgent",
"BeliefCollectorAgent",
"LLMAgent",
"RICommandAgent",
"RICommunicationAgent",
"TranscriptionAgent",
"VADAgent",
]

View File

@@ -0,0 +1,18 @@
import logging
from spade.agent import Agent
class BaseAgent(Agent):
    """
    Base agent class for our agents to inherit from.

    Every subclass automatically receives a class-level ``logger`` that is a
    child of this package's logger and is named after the subclass (e.g.
    ``control_backend.agents.AgentName``), so agents and their behaviours can
    log through ``self.logger`` / ``self.agent.logger`` without creating their
    own logger.
    """

    # ``__init_subclass__`` only runs for subclasses, so give the base class
    # itself a usable logger as well; subclasses overwrite it with their own.
    logger: logging.Logger = logging.getLogger(__package__).getChild("BaseAgent")

    def __init_subclass__(cls, **kwargs) -> None:
        """Assign a logger named after the subclass when the subclass is defined."""
        super().__init_subclass__(**kwargs)
        cls.logger = logging.getLogger(__package__).getChild(cls.__name__)

View File

@@ -5,9 +5,10 @@ from spade.behaviour import OneShotBehaviour
from spade.message import Message
from spade_bdi.bdi import BDIAgent
from control_backend.core.config import settings
from .behaviours.belief_setter import BeliefSetterBehaviour
from .behaviours.receive_llm_resp_behaviour import ReceiveLLMResponseBehaviour
from control_backend.core.config import settings
class BDICoreAgent(BDIAgent):
@@ -18,7 +19,7 @@ class BDICoreAgent(BDIAgent):
It has the BeliefSetter behaviour and can ask and receive requests from the LLM agent.
"""
logger = logging.getLogger("bdi_core_agent")
# ``__name__`` is already the fully qualified module name (package included),
# so chaining it onto the package logger would duplicate the package prefix.
logger = logging.getLogger(__name__)
async def setup(self) -> None:
"""
@@ -56,11 +57,11 @@ class BDICoreAgent(BDIAgent):
class SendBehaviour(OneShotBehaviour):
async def run(self) -> None:
msg = Message(
to= settings.agent_settings.llm_agent_name + '@' + settings.agent_settings.host,
body= text
to=settings.agent_settings.llm_agent_name + "@" + settings.agent_settings.host,
body=text,
)
await self.send(msg)
self.agent.logger.info("Message sent to LLM agent: %s", text)
self.add_behaviour(SendBehaviour())
self.add_behaviour(SendBehaviour())

View File

@@ -1,5 +1,4 @@
import json
import logging
from spade.agent import Message
from spade.behaviour import CyclicBehaviour
@@ -15,12 +14,11 @@ class BeliefSetterBehaviour(CyclicBehaviour):
"""
agent: BDIAgent
logger = logging.getLogger(__name__)
async def run(self):
"""Polls for messages and processes them."""
msg = await self.receive()
self.logger.debug(
self.agent.logger.debug(
"Received message from %s with thread '%s' and body: %s",
msg.sender,
msg.thread,
@@ -28,23 +26,24 @@ class BeliefSetterBehaviour(CyclicBehaviour):
)
self._process_message(msg)
def _process_message(self, message: Message):
"""Routes the message to the correct processing function based on the sender."""
sender = message.sender.node # removes host from jid and converts to str
self.logger.debug("Processing message from sender: %s", sender)
self.agent.logger.debug("Processing message from sender: %s", sender)
match sender:
case settings.agent_settings.belief_collector_agent_name:
self.logger.debug("Message is from the belief collector agent. Processing as belief message.")
self.agent.logger.debug(
"Message is from the belief collector agent. Processing as belief message."
)
self._process_belief_message(message)
case _:
self.logger.debug("Not the belief agent, discarding message")
self.agent.logger.debug("Not the belief agent, discarding message")
pass
def _process_belief_message(self, message: Message):
if not message.body:
self.logger.debug("Ignoring message with empty body from %s", message.sender.node)
self.agent.logger.debug("Ignoring message with empty body from %s", message.sender.node)
return
match message.thread:
@@ -53,10 +52,10 @@ class BeliefSetterBehaviour(CyclicBehaviour):
beliefs: dict[str, list[str]] = json.loads(message.body)
self._set_beliefs(beliefs)
except json.JSONDecodeError:
self.logger.error(
self.agent.logger.error(
"Could not decode beliefs from JSON. Message body: '%s'",
message.body,
exc_info=True
exc_info=True,
)
case _:
pass
@@ -64,21 +63,23 @@ class BeliefSetterBehaviour(CyclicBehaviour):
def _set_beliefs(self, beliefs: dict[str, list[str]]):
"""Removes previous values for beliefs and updates them with the provided values."""
if self.agent.bdi is None:
self.logger.warning("Cannot set beliefs; agent's BDI is not yet initialized.")
self.agent.logger.warning("Cannot set beliefs; agent's BDI is not yet initialized.")
return
if not beliefs:
self.logger.debug("Received an empty set of beliefs. No beliefs were updated.")
self.agent.logger.debug("Received an empty set of beliefs. No beliefs were updated.")
return
# Set new beliefs (outdated beliefs are automatically removed)
for belief, arguments in beliefs.items():
self.logger.debug("Setting belief %s with arguments %s", belief, arguments)
self.agent.logger.debug("Setting belief %s with arguments %s", belief, arguments)
self.agent.bdi.set_belief(belief, *arguments)
# Special case: if there's a new user message, flag that we haven't responded yet
if belief == "user_said":
self.agent.bdi.set_belief("new_message")
self.logger.debug("Detected 'user_said' belief, also setting 'new_message' belief.")
self.agent.logger.debug(
"Detected 'user_said' belief, also setting 'new_message' belief."
)
self.logger.info("Successfully updated %d beliefs.", len(beliefs))
self.agent.logger.info("Successfully updated %d beliefs.", len(beliefs))

View File

@@ -1,5 +1,3 @@
import logging
from spade.behaviour import CyclicBehaviour
from control_backend.core.config import settings
@@ -9,16 +7,16 @@ class ReceiveLLMResponseBehaviour(CyclicBehaviour):
"""
Adds behavior to receive responses from the LLM Agent.
"""
logger = logging.getLogger(__name__)
async def run(self):
msg = await self.receive()
sender = msg.sender.node
sender = msg.sender.node
match sender:
case settings.agent_settings.llm_agent_name:
content = msg.body
self.logger.info("Received LLM response: %s", content)
#Here the BDI can pass the message back as a response
self.agent.logger.info("Received LLM response: %s", content)
# Here the BDI can pass the message back as a response
case _:
self.logger.debug("Discarding message from %s", sender)
pass
self.agent.logger.debug("Discarding message from %s", sender)
pass

View File

@@ -1,6 +1,4 @@
import asyncio
import json
import logging
from spade.behaviour import CyclicBehaviour
from spade.message import Message
@@ -9,8 +7,6 @@ from control_backend.core.config import settings
class BeliefFromText(CyclicBehaviour):
logger = logging.getLogger("Belief From Text")
# TODO: LLM prompt is still hardcoded
llm_instruction_prompt = """
You are an information extraction assistent for a BDI agent. Your task is to extract values from a user's text to bind a list of ungrounded beliefs. Rules:
@@ -25,16 +21,13 @@ class BeliefFromText(CyclicBehaviour):
"""
# on_start agent receives message containing the beliefs to look out for and sets up the LLM with instruction prompt
#async def on_start(self):
# async def on_start(self):
# msg = await self.receive(timeout=0.1)
# self.beliefs = dict uit message
# send instruction prompt to LLM
beliefs: dict[str, list[str]]
beliefs = {
"mood": ["X"],
"car": ["Y"]
}
beliefs = {"mood": ["X"], "car": ["Y"]}
async def run(self):
msg = await self.receive()
@@ -56,8 +49,8 @@ class BeliefFromText(CyclicBehaviour):
prompt = text_prompt + beliefs_prompt
# The behaviour no longer defines its own logger; log through the owning agent.
self.agent.logger.info(prompt)
#prompt_msg = Message(to="LLMAgent@whatever")
#response = self.send(prompt_msg)
# prompt_msg = Message(to="LLMAgent@whatever")
# response = self.send(prompt_msg)
# Mock response; response is beliefs in JSON format, it parses to dict[str, list[list[str]]]
response = '{"mood": [["happy"]]}'
@@ -65,15 +58,16 @@ class BeliefFromText(CyclicBehaviour):
try:
json.loads(response)
belief_message = Message(
to=settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host,
body=response)
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
body=response,
)
belief_message.thread = "beliefs"
await self.send(belief_message)
self.logger.info("Sent beliefs to BDI.")
self.agent.logger.info("Sent beliefs to BDI.")
except json.JSONDecodeError:
# Parsing failed, so the response is in the wrong format, log warning
self.logger.warning("Received LLM response in incorrect format.")
self.agent.logger.warning("Received LLM response in incorrect format.")
async def _process_transcription_demo(self, txt: str):
"""
@@ -83,9 +77,12 @@ class BeliefFromText(CyclicBehaviour):
"""
belief = {"beliefs": {"user_said": [txt]}, "type": "belief_extraction_text"}
payload = json.dumps(belief)
belief_msg = Message(to=settings.agent_settings.belief_collector_agent_name
+ '@' + settings.agent_settings.host,
body=payload)
belief_msg = Message(
to=settings.agent_settings.belief_collector_agent_name
+ "@"
+ settings.agent_settings.host,
body=payload,
)
belief_msg.thread = "beliefs"
await self.send(belief_msg)

View File

@@ -1,8 +1,8 @@
from spade.agent import Agent
from control_backend.agents.base import BaseAgent
from .behaviours.text_belief_extractor import BeliefFromText
class TBeliefExtractorAgent(Agent):
class TBeliefExtractorAgent(BaseAgent):
async def setup(self):
self.add_behaviour(BeliefFromText())
self.add_behaviour(BeliefFromText())

View File

@@ -1,23 +1,22 @@
import json
import logging
from json import JSONDecodeError
from spade.behaviour import CyclicBehaviour
from spade.agent import Message
from spade.behaviour import CyclicBehaviour
from control_backend.core.config import settings
class ContinuousBeliefCollector(CyclicBehaviour):
"""
Continuously collects beliefs/emotions from extractor agents:
Then we send a unified belief packet to the BDI agent.
"""
logger = logging.getLogger(__name__)
async def run(self):
msg = await self.receive()
await self._process_message(msg)
async def _process_message(self, msg: Message):
sender_node = msg.sender.node
@@ -25,9 +24,8 @@ class ContinuousBeliefCollector(CyclicBehaviour):
try:
payload = json.loads(msg.body)
except JSONDecodeError as e:
self.logger.warning(
"Failed to parse JSON from %s. Body=%r Error=%s",
sender_node, msg.body, e
self.agent.logger.warning(
"Failed to parse JSON from %s. Body=%r Error=%s", sender_node, msg.body, e
)
return
@@ -35,19 +33,19 @@ class ContinuousBeliefCollector(CyclicBehaviour):
# Prefer explicit 'type' field
if msg_type == "belief_extraction_text" or sender_node == "belief_text_agent_mock":
self.logger.debug("Message routed to _handle_belief_text (sender=%s)", sender_node)
self.agent.logger.debug(
"Message routed to _handle_belief_text (sender=%s)", sender_node
)
await self._handle_belief_text(payload, sender_node)
#This is not implemented yet, but we keep the structure for future use
elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock":
self.logger.debug("Message routed to _handle_emo_text (sender=%s)", sender_node)
# This is not implemented yet, but we keep the structure for future use
elif msg_type == "emotion_extraction_text" or sender_node == "emo_text_agent_mock":
self.agent.logger.debug("Message routed to _handle_emo_text (sender=%s)", sender_node)
await self._handle_emo_text(payload, sender_node)
else:
self.logger.warning(
"Unrecognized message (sender=%s, type=%r). Ignoring.",
sender_node, msg_type
self.agent.logger.warning(
"Unrecognized message (sender=%s, type=%r). Ignoring.", sender_node, msg_type
)
async def _handle_belief_text(self, payload: dict, origin: str):
"""
Expected payload:
@@ -61,23 +59,20 @@ class ContinuousBeliefCollector(CyclicBehaviour):
beliefs = payload.get("beliefs", {})
if not beliefs:
self.logger.debug("Received empty beliefs set.")
self.agent.logger.debug("Received empty beliefs set.")
return
self.logger.debug("Forwarding %d beliefs.", len(beliefs))
self.agent.logger.debug("Forwarding %d beliefs.", len(beliefs))
for belief_name, belief_list in beliefs.items():
for belief in belief_list:
self.logger.debug(" - %s %s", belief_name,str(belief))
self.agent.logger.debug(" - %s %s", belief_name, str(belief))
await self._send_beliefs_to_bdi(beliefs, origin=origin)
async def _handle_emo_text(self, payload: dict, origin: str):
"""TODO: implement (after we have emotional recogntion)"""
pass
async def _send_beliefs_to_bdi(self, beliefs: list[str], origin: str | None = None):
"""
Sends a unified belief packet to the BDI agent.
@@ -90,6 +85,5 @@ class ContinuousBeliefCollector(CyclicBehaviour):
msg = Message(to=to_jid, sender=self.agent.jid, thread="beliefs")
msg.body = json.dumps(beliefs)
await self.send(msg)
self.logger.info("Sent %d belief(s) to BDI core.", len(beliefs))
self.agent.logger.info("Sent %d belief(s) to BDI core.", len(beliefs))

View File

@@ -1,13 +1,11 @@
import logging
from spade.agent import Agent
from control_backend.agents.base import BaseAgent
from .behaviours.continuous_collect import ContinuousBeliefCollector
logger = logging.getLogger(__name__)
class BeliefCollectorAgent(Agent):
class BeliefCollectorAgent(BaseAgent):
async def setup(self):
logger.info("BeliefCollectorAgent starting (%s)", self.jid)
self.logger.info("BeliefCollectorAgent starting (%s)", self.jid)
# Attach the continuous collector behaviour (listens and forwards to BDI)
self.add_behaviour(ContinuousBeliefCollector())
logger.info("BeliefCollectorAgent ready.")
self.logger.info("BeliefCollectorAgent ready.")

View File

@@ -1,29 +1,22 @@
"""
LLM Agent module for routing text queries from the BDI Core Agent to a local LLM
service and returning its responses back to the BDI Core Agent.
"""
import logging
from typing import Any
import httpx
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from .llm_instructions import LLMInstructions
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from .llm_instructions import LLMInstructions
class LLMAgent(Agent):
class LLMAgent(BaseAgent):
"""
Agent responsible for processing user text input and querying a locally
hosted LLM for text generation. Receives messages from the BDI Core Agent
and responds with processed LLM output.
"""
logger = logging.getLogger("llm_agent")
class ReceiveMessageBehaviour(CyclicBehaviour):
"""
Cyclic behaviour to continuously listen for incoming messages from
@@ -63,8 +56,8 @@ class LLMAgent(Agent):
Sends a response message back to the BDI Core Agent.
"""
reply = Message(
to=settings.agent_settings.bdi_core_agent_name + '@' + settings.agent_settings.host,
body=msg
to=settings.agent_settings.bdi_core_agent_name + "@" + settings.agent_settings.host,
body=msg,
)
await self.send(reply)
self.agent.logger.info("Reply sent to BDI Core Agent")
@@ -78,35 +71,31 @@ class LLMAgent(Agent):
"""
async with httpx.AsyncClient(timeout=120.0) as client:
# Example dynamic content for future (optional)
instructions = LLMInstructions()
developer_instruction = instructions.build_developer_instruction()
response = await client.post(
settings.llm_settings.local_llm_url,
headers={"Content-Type": "application/json"},
json={
"model": settings.llm_settings.local_llm_model,
"messages": [
{
"role": "developer",
"content": developer_instruction
},
{
"role": "user",
"content": prompt
}
{"role": "developer", "content": developer_instruction},
{"role": "user", "content": prompt},
],
"temperature": 0.3
"temperature": 0.3,
},
)
try:
response.raise_for_status()
data: dict[str, Any] = response.json()
return data.get("choices", [{}])[0].get(
"message", {}
).get("content", "No response")
return (
data.get("choices", [{}])[0]
.get("message", {})
.get("content", "No response")
)
except httpx.HTTPError as err:
self.agent.logger.error("HTTP error: %s", err)
return "LLM service unavailable."

View File

@@ -1,9 +1,12 @@
import json
from spade.agent import Agent
from spade.behaviour import OneShotBehaviour
from spade.message import Message
from control_backend.core.config import settings
class BeliefTextAgent(Agent):
class SendOnceBehaviourBlfText(OneShotBehaviour):
async def run(self):
@@ -12,7 +15,15 @@ class BeliefTextAgent(Agent):
# Send multiple beliefs in one JSON payload
payload = {
"type": "belief_extraction_text",
"beliefs": {"user_said": ["hello test","Can you help me?","stop talking to me","No","Pepper do a dance"]}
"beliefs": {
"user_said": [
"hello test",
"Can you help me?",
"stop talking to me",
"No",
"Pepper do a dance",
]
},
}
msg = Message(to=to_jid)

View File

@@ -1,17 +1,15 @@
import json
import logging
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
import zmq
import zmq
from spade.behaviour import CyclicBehaviour
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from control_backend.schemas.ri_message import SpeechCommand
logger = logging.getLogger(__name__)
class RICommandAgent(Agent):
class RICommandAgent(BaseAgent):
subsocket: zmq.Socket
pubsocket: zmq.Socket
address = ""
@@ -47,13 +45,13 @@ class RICommandAgent(Agent):
# Send to the robot.
await self.agent.pubsocket.send_json(message.model_dump())
except Exception as e:
logger.error("Error processing message: %s", e)
# This runs inside a behaviour, which has no logger of its own; use the agent's.
self.agent.logger.error("Error processing message: %s", e)
async def setup(self):
"""
Setup the command agent
"""
logger.info("Setting up %s", self.jid)
self.logger.info("Setting up %s", self.jid)
# To the robot
self.pubsocket = context.socket(zmq.PUB)
@@ -71,4 +69,4 @@ class RICommandAgent(Agent):
commands_behaviour = self.SendCommandsBehaviour()
self.add_behaviour(commands_behaviour)
logger.info("Finished setting up %s", self.jid)
self.logger.info("Finished setting up %s", self.jid)

View File

@@ -1,17 +1,16 @@
import asyncio
import logging
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
import zmq
import zmq
from spade.behaviour import CyclicBehaviour
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from control_backend.core.zmq_context import context
from .ri_command_agent import RICommandAgent
logger = logging.getLogger(__name__)
class RICommunicationAgent(Agent):
class RICommunicationAgent(BaseAgent):
req_socket: zmq.Socket
_address = ""
_bind = True
@@ -45,13 +44,13 @@ class RICommunicationAgent(Agent):
message = await asyncio.wait_for(self.agent.req_socket.recv_json(), timeout=3.0)
# We didnt get a reply :(
except asyncio.TimeoutError as e:
logger.info("No ping retrieved in 3 seconds, killing myself.")
except TimeoutError:
self.agent.logger.info("No ping retrieved in 3 seconds, killing myself.")
self.kill()
logger.debug('Received message "%s"', message)
self.agent.logger.debug('Received message "%s"', message)
if "endpoint" not in message:
logger.error("No received endpoint in message, excepted ping endpoint.")
self.agent.logger.error("No received endpoint in message, excepted ping endpoint.")
return
# See what endpoint we received
@@ -59,7 +58,7 @@ class RICommunicationAgent(Agent):
case "ping":
await asyncio.sleep(1)
case _:
logger.info(
self.agent.logger.info(
"Received message with topic different than ping, while ping expected."
)
@@ -67,7 +66,7 @@ class RICommunicationAgent(Agent):
"""
Try to setup the communication agent, we have 5 retries in case we dont have a response yet.
"""
logger.info("Setting up %s", self.jid)
self.logger.info("Setting up %s", self.jid)
retries = 0
# Let's try a certain amount of times before failing connection
@@ -86,8 +85,8 @@ class RICommunicationAgent(Agent):
try:
received_message = await asyncio.wait_for(self.req_socket.recv_json(), timeout=20.0)
except asyncio.TimeoutError:
logger.warning(
except TimeoutError:
self.logger.warning(
"No connection established in 20 seconds (attempt %d/%d)",
retries + 1,
max_retries,
@@ -96,7 +95,7 @@ class RICommunicationAgent(Agent):
continue
except Exception as e:
logger.error("Unexpected error during negotiation: %s", e)
self.logger.error("Unexpected error during negotiation: %s", e)
retries += 1
continue
@@ -104,7 +103,7 @@ class RICommunicationAgent(Agent):
endpoint = received_message.get("endpoint")
if endpoint != "negotiate/ports":
# TODO: Should this send a message back?
logger.error(
self.logger.error(
"Invalid endpoint '%s' received (attempt %d/%d)",
endpoint,
retries + 1,
@@ -143,10 +142,10 @@ class RICommunicationAgent(Agent):
)
await ri_commands_agent.start()
case _:
logger.warning("Unhandled negotiation id: %s", id)
self.logger.warning("Unhandled negotiation id: %s", id)
except Exception as e:
logger.error("Error unpacking negotiation data: %s", e)
self.logger.error("Error unpacking negotiation data: %s", e)
retries += 1
continue
@@ -154,10 +153,10 @@ class RICommunicationAgent(Agent):
break
else:
logger.error("Failed to set up RICommunicationAgent after %d retries", max_retries)
self.logger.error("Failed to set up RICommunicationAgent after %d retries", max_retries)
return
# Set up ping behaviour
listen_behaviour = self.ListenBehaviour()
self.add_behaviour(listen_behaviour)
logger.info("Finished setting up %s", self.jid)
self.logger.info("Finished setting up %s", self.jid)

View File

@@ -1,21 +1,19 @@
import asyncio
import logging
import numpy as np
import zmq
import zmq.asyncio as azmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
from spade.message import Message
from .speech_recognizer import SpeechRecognizer
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context
logger = logging.getLogger(__name__)
from .speech_recognizer import SpeechRecognizer
class TranscriptionAgent(Agent):
class TranscriptionAgent(BaseAgent):
"""
An agent which listens to audio fragments with voice, transcribes them, and sends the
transcription to other agents.
@@ -47,7 +45,8 @@ class TranscriptionAgent(Agent):
"""Share a transcription to the other agents that depend on it."""
receiver_jids = [
settings.agent_settings.text_belief_extractor_agent_name
+ '@' + settings.agent_settings.host,
+ "@"
+ settings.agent_settings.host,
] # Set message receivers here
for receiver_jid in receiver_jids:
@@ -58,7 +57,7 @@ class TranscriptionAgent(Agent):
audio = await self.audio_in_socket.recv()
audio = np.frombuffer(audio, dtype=np.float32)
speech = await self._transcribe(audio)
logger.info("Transcribed speech: %s", speech)
self.agent.logger.info("Transcribed speech: %s", speech)
await self._share_transcription(speech)
@@ -73,7 +72,7 @@ class TranscriptionAgent(Agent):
self.audio_in_socket.connect(self.audio_in_address)
async def setup(self):
logger.info("Setting up %s", self.jid)
self.logger.info("Setting up %s", self.jid)
self._connect_audio_in_socket()
@@ -81,4 +80,4 @@ class TranscriptionAgent(Agent):
transcribing.warmup()
self.add_behaviour(transcribing)
logger.info("Finished setting up %s", self.jid)
self.logger.info("Finished setting up %s", self.jid)

View File

@@ -1,17 +1,14 @@
import logging
import numpy as np
import torch
import zmq
import zmq.asyncio as azmq
from spade.agent import Agent
from spade.behaviour import CyclicBehaviour
from .transcription.transcription_agent import TranscriptionAgent
from control_backend.agents import BaseAgent
from control_backend.core.config import settings
from control_backend.core.zmq_context import context as zmq_context
logger = logging.getLogger(__name__)
from .transcription.transcription_agent import TranscriptionAgent
class SocketPoller[T]:
@@ -60,7 +57,9 @@ class Streaming(CyclicBehaviour):
data = await self.audio_in_poller.poll()
if data is None:
if len(self.audio_buffer) > 0:
logger.debug("No audio data received. Discarding buffer until new data arrives.")
self.agent.logger.debug(
"No audio data received. Discarding buffer until new data arrives."
)
self.audio_buffer = np.array([], dtype=np.float32)
self.i_since_speech = 100
return
@@ -71,7 +70,7 @@ class Streaming(CyclicBehaviour):
if prob > 0.5:
if self.i_since_speech > 3:
logger.debug("Speech started.")
self.agent.logger.debug("Speech started.")
self.audio_buffer = np.append(self.audio_buffer, chunk)
self.i_since_speech = 0
return
@@ -84,7 +83,7 @@ class Streaming(CyclicBehaviour):
# Speech probably ended. Make sure we have a usable amount of data.
if len(self.audio_buffer) >= 3 * len(chunk):
logger.debug("Speech ended.")
self.agent.logger.debug("Speech ended.")
await self.audio_out_socket.send(self.audio_buffer[: -2 * len(chunk)].tobytes())
# At this point, we know that the speech has ended.
@@ -92,7 +91,7 @@ class Streaming(CyclicBehaviour):
self.audio_buffer = chunk
class VADAgent(Agent):
class VADAgent(BaseAgent):
"""
An agent which listens to an audio stream, does Voice Activity Detection (VAD), and sends
fragments with detected speech to other agents over ZeroMQ.
@@ -135,12 +134,12 @@ class VADAgent(Agent):
self.audio_out_socket = zmq_context.socket(zmq.PUB)
return self.audio_out_socket.bind_to_random_port("tcp://*", max_tries=100)
except zmq.ZMQBindError:
logger.error("Failed to bind an audio output socket after 100 tries.")
self.logger.error("Failed to bind an audio output socket after 100 tries.")
self.audio_out_socket = None
return None
async def setup(self):
logger.info("Setting up %s", self.jid)
self.logger.info("Setting up %s", self.jid)
self._connect_audio_in_socket()
@@ -157,4 +156,4 @@ class VADAgent(Agent):
transcriber = TranscriptionAgent(audio_out_address)
await transcriber.start()
logger.info("Finished setting up %s", self.jid)
self.logger.info("Finished setting up %s", self.jid)