159 lines
5.4 KiB
Python
159 lines
5.4 KiB
Python
import contextlib
|
|
import logging
|
|
import threading
|
|
|
|
import zmq
|
|
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from zmq.asyncio import Context
|
|
|
|
from control_backend.agents import (
|
|
BeliefCollectorAgent,
|
|
LLMAgent,
|
|
RICommunicationAgent,
|
|
VADAgent,
|
|
)
|
|
from control_backend.agents.bdi import BDICoreAgent, TBeliefExtractorAgent
|
|
from control_backend.api.v1.router import api_router
|
|
from control_backend.core.config import settings
|
|
from control_backend.logging import setup_logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def setup_sockets():
    """
    Run the internal XSUB/XPUB proxy that relays PUB/SUB traffic.

    The XPUB socket is bound to ``settings.zmq_settings.internal_sub_address``
    and the XSUB socket to ``settings.zmq_settings.internal_pub_address``.
    Publishers (e.g. the endpoints PUB socket created in ``lifespan``) connect
    to the XSUB side; subscribers are expected to connect to the XPUB side.

    ``zmq.proxy`` blocks until it fails or the ZMQ context is terminated, so
    this function is meant to be the target of a dedicated thread.

    Both sockets are always closed on exit, including when a ``bind`` fails;
    a bind error still propagates to the caller.
    """
    context = Context.instance()

    # The names are crossed on purpose: subscribers connect to the XPUB side
    # and publishers connect to the XSUB side.
    xpub_address = settings.zmq_settings.internal_sub_address
    xsub_address = settings.zmq_settings.internal_pub_address

    # contextlib.closing guarantees close() even if bind() raises, which the
    # original try/finally (placed after both binds) did not cover.
    with contextlib.closing(context.socket(zmq.XPUB)) as internal_pub_socket:
        internal_pub_socket.bind(xpub_address)
        # Log the address itself; logging the socket object only shows its repr.
        logger.debug("Internal publishing socket bound to %s", xpub_address)

        with contextlib.closing(context.socket(zmq.XSUB)) as internal_sub_socket:
            internal_sub_socket.bind(xsub_address)
            logger.debug("Internal subscribing socket bound to %s", xsub_address)

            try:
                zmq.proxy(internal_sub_socket, internal_pub_socket)
            except zmq.ZMQError:
                # Keep the traceback so proxy failures can be diagnosed.
                logger.warning(
                    "Error while handling PUB/SUB proxy. Closing sockets.",
                    exc_info=True,
                )
|
|
|
|
|
|
def _build_agent_specs():
    """
    Return the agents to start at application startup, keyed by display name.

    Each value is an ``(agent_class, kwargs)`` pair. Dict insertion order is
    the start order. The ``"name"`` entry in ``kwargs`` is not passed to the
    constructor (``_start_agents`` strips it).
    """
    agent_settings = settings.agent_settings
    host = agent_settings.host

    return {
        "RICommunicationAgent": (
            RICommunicationAgent,
            {
                "name": agent_settings.ri_communication_agent_name,
                "jid": f"{agent_settings.ri_communication_agent_name}@{host}",
                # NOTE(review): password reuses the agent name; confirm this is
                # only intended for a local/dev XMPP server.
                "password": agent_settings.ri_communication_agent_name,
                # NOTE(review): hard-coded address; consider moving to settings.
                "address": "tcp://*:5555",
                "bind": True,
            },
        ),
        "LLMAgent": (
            LLMAgent,
            {
                "name": agent_settings.llm_agent_name,
                "jid": f"{agent_settings.llm_agent_name}@{host}",
                "password": agent_settings.llm_agent_name,
            },
        ),
        "BDICoreAgent": (
            BDICoreAgent,
            {
                "name": agent_settings.bdi_core_agent_name,
                "jid": f"{agent_settings.bdi_core_agent_name}@{host}",
                "password": agent_settings.bdi_core_agent_name,
                # NOTE(review): relative path, resolved against the process
                # working directory at startup.
                "asl": "src/control_backend/agents/bdi/rules.asl",
            },
        ),
        "BeliefCollectorAgent": (
            BeliefCollectorAgent,
            {
                "name": agent_settings.belief_collector_agent_name,
                "jid": f"{agent_settings.belief_collector_agent_name}@{host}",
                "password": agent_settings.belief_collector_agent_name,
            },
        ),
        "TBeliefExtractor": (
            TBeliefExtractorAgent,
            {
                "name": agent_settings.text_belief_extractor_agent_name,
                "jid": f"{agent_settings.text_belief_extractor_agent_name}@{host}",
                "password": agent_settings.text_belief_extractor_agent_name,
            },
        ),
        "VADAgent": (
            VADAgent,
            # NOTE(review): hard-coded address; consider moving to settings.
            {"audio_in_address": "tcp://localhost:5558", "audio_in_bind": False},
        ),
    }


async def _start_agents(agent_specs):
    """
    Instantiate and start each agent from ``agent_specs`` in order.

    Re-raises the exception of the first agent that fails to start. Agents
    started before it are not rolled back.
    """
    for name, (agent_class, kwargs) in agent_specs.items():
        try:
            logger.debug("Starting agent: %s", name)
            agent_instance = agent_class(**{k: v for k, v in kwargs.items() if k != "name"})
            await agent_instance.start()
            logger.info("Agent '%s' started successfully.", name)
        except Exception as e:
            logger.error("Failed to start agent '%s': %s", name, e, exc_info=True)
            # Consider if the application should continue if an agent fails to start.
            raise


@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan context manager to handle startup and shutdown events.

    Startup does four things:

    - configures logging;
    - launches the internal ZMQ proxy (``setup_sockets``) on a daemon thread;
    - connects a PUB socket exposed to endpoints as
      ``app.state.endpoints_pub_socket``;
    - starts all agents.

    That PUB socket is closed on shutdown, and also when startup fails.
    """
    # --- APPLICATION STARTUP ---
    setup_logging()
    logger.info("%s is starting up.", app.title)

    # zmq.proxy blocks, so it gets its own thread. daemon=True keeps the thread
    # from holding the process open on exit.
    proxy_thread = threading.Thread(target=setup_sockets, daemon=True)
    proxy_thread.start()

    context = Context.instance()
    endpoints_pub_socket = context.socket(zmq.PUB)
    try:
        # Connects to the proxy's XSUB side bound in setup_sockets.
        endpoints_pub_socket.connect(settings.zmq_settings.internal_pub_address)
        app.state.endpoints_pub_socket = endpoints_pub_socket

        # --- Initialize Agents ---
        logger.info("Initializing and starting agents.")
        await _start_agents(_build_agent_specs())

        logger.info("Application startup complete.")

        yield

        # --- APPLICATION SHUTDOWN ---
        logger.info("%s is shutting down.", app.title)
        # TODO: stop the agents started above; confirm every agent class
        # exposes a stop() coroutine before adding it.
    finally:
        # Runs on normal shutdown and when startup raises, so the socket
        # is never leaked.
        endpoints_pub_socket.close()

    logger.info("Application shutdown complete.")
|
|
|
|
|
|
# The ASGI application object served by the web server; startup and shutdown
# work is delegated to the ``lifespan`` context manager.
app = FastAPI(lifespan=lifespan, title=settings.app_title)

# Cross-origin access: the UI lives at a different origin (settings.ui_url) and
# must be allowed to call this API from the browser, with any HTTP method.
# Background: https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_origins=[settings.ui_url],
)

# Routes are currently mounted at the root.
# TODO: make prefix /api/v1
app.include_router(api_router, prefix="")
|
|
|
|
|
|
@app.get("/")
async def root():
    """Handle ``GET /`` by always returning the static ``{"status": "ok"}`` payload."""
    response_body = {"status": "ok"}
    return response_body
|