Using WebSockets with AutoGen

AutoGen provides a default WebSocket implementation, documented at https://microsoft.github.io/autogen/docs/notebooks/agentchat_websockets. However, it has some limitations, notably:

  1. It isn’t compatible with ASGI servers that run multiple instances of a FastAPI server.

For a production-grade deployment, you typically want many instances of FastAPI or Django served by an ASGI server such as Uvicorn.

This notebook therefore demonstrates an alternative approach that works seamlessly with such deployments and scales out as demand grows.

Requirements

This notebook needs a few extra dependencies, which can be installed via pip (the quotes stop shells such as zsh from expanding the square brackets):

pip install "pyautogen[websockets]" fastapi uvicorn

Define your Agents

import autogen

# LLM-backed assistant. Its model list is read from the OAI_CONFIG_LIST
# environment variable or file and filtered to the models named below.
agent = autogen.ConversableAgent(
    name="chatbot",
    system_message=(
        "Complete a task given to you and reply TERMINATE when the task is done. "
        "If asked about the weather, use tool 'weather_forecast(city)' to get the "
        "weather forecast for a city."
    ),
    llm_config={
        "config_list": autogen.config_list_from_json(
            env_or_file="OAI_CONFIG_LIST",
            filter_dict={"model": ["gpt-4", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"]},
        ),
        # Presumably enables streamed (token-by-token) responses; confirm in
        # the AutoGen llm_config documentation.
        "stream": True,
    },
)

# Agent standing in for the user. Human input mode is "NEVER", code execution
# is disabled, and automatic replies are capped at 10 in a row.
user_proxy = autogen.UserProxyAgent(
    name="user_proxy",
    system_message="A proxy for the user.",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=10,
    code_execution_config=False,
    # Truthy when the message content ends with "TERMINATE" (trailing
    # whitespace ignored); missing or empty content yields a falsy value.
    is_termination_msg=lambda message: (
        message.get("content", "")
        and message.get("content", "").rstrip().endswith("TERMINATE")
    ),
)

def weather_forecast(city: str) -> str:
    """Return a canned "sunny" forecast for *city*, stamped with the local time.

    Demo tool for the agent: it never consults a weather service.

    Args:
        city: Name of the city to report on.

    Returns:
        A sentence naming the city and the current (naive, local) datetime.
    """
    # Imported locally so the tool works even though the snippet never
    # imports datetime at module level.
    from datetime import datetime

    return f"The weather forecast for {city} at {datetime.now()} is sunny."

# Register weather_forecast as a tool: `agent` is the caller that may request
# it, and `user_proxy` is the executor that runs it.
autogen.register_function(
    weather_forecast,
    executor=user_proxy,
    caller=agent,
    description="Weather forecast for a city",
)

Create a custom IOStream that handles websocket connections

import asyncio
import json
import threading
from typing import Any

from autogen.io.base import IOStream
from fastapi import FastAPI, WebSocket

class CustomIOWebsockets(IOStream):
    r"""An IOStream that reads from and writes to an accepted FastAPI WebSocket.

    Attributes:
        _websocket (WebSocket): The accepted client connection.
        _loop (asyncio.AbstractEventLoop | None): The event loop that was
            running when the stream was created (the loop serving
            ``_websocket``), or ``None`` if created outside a running loop.
    """

    def __init__(self, websocket: WebSocket) -> None:
        """Wrap an already-accepted websocket connection.

        Args:
            websocket (WebSocket): The accepted client connection.
        """
        self._websocket = websocket
        # Remember the loop serving this connection so print() can hand sends
        # back to it when it is called from a different thread.
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:  # constructed outside a running event loop
            self._loop = None

    @staticmethod
    async def handler(websocket: WebSocket, on_connect, *args, **kwargs) -> None:
        """Run *on_connect* with a new stream for *websocket* installed as default.

        Args:
            websocket (WebSocket): The accepted client connection.
            on_connect: Coroutine function, awaited as
                ``on_connect(iostream, *args, **kwargs)``.
        """
        logger.debug(
            " - CustomIOWebsockets._handler(): Client connected on %s", websocket
        )
        iowebsocket = CustomIOWebsockets(websocket)
        # While on_connect runs, code that looks up the default IOStream
        # (presumably AutoGen's agents) writes to this client.
        with CustomIOWebsockets.set_default(iowebsocket):
            await on_connect(iowebsocket, *args, **kwargs)

    @property
    def websocket(self) -> "WebSocket":
        """The wrapped client WebSocket connection."""
        return self._websocket

    def print(
        self, *objects: Any, sep: str = " ", end: str = "\n", flush: bool = False
    ) -> None:
        r"""Send *objects* to the client as a single text frame.

        If the first object is a dict, every object is JSON-encoded; otherwise
        each object is converted with ``str``. This is a synchronous call that
        blocks until the frame has been sent.

        Args:
            objects (any): The data to print.
            sep (str, optional): The separator between objects.
            Defaults to " ".
            end (str, optional): Appended after the last object.
            Defaults to "\n".
            flush (bool, optional): Unused; accepted for signature
            compatibility. Defaults to False.

        Raises:
            Exception: Any error raised while sending (for example because the
            client has disconnected) is re-raised to the caller.
        """
        # `objects` is always a tuple; guard against an empty call before
        # peeking at the first element.
        if objects and isinstance(objects[0], dict):
            text = sep.join(map(json.dumps, objects)) + end
        else:
            text = sep.join(map(str, objects)) + end
        if not text:
            return
        self._send_blocking({"type": "websocket.send", "text": text})

    def _send_blocking(self, message: dict) -> None:
        """Send ASGI websocket *message* from synchronous code and wait for it."""
        try:
            current = asyncio.get_running_loop()
        except RuntimeError:
            current = None
        owner = self._loop

        if owner is not None and owner.is_running() and current is not owner:
            # Called from another thread while the owning loop is free:
            # schedule the send on that loop and wait for it to finish.
            future = asyncio.run_coroutine_threadsafe(
                self._websocket.send(message), owner
            )
            future.result()
            return

        # Called synchronously from inside the owning loop (which this call is
        # therefore blocking), or with no known loop: run the send on a private
        # loop in a helper thread. NOTE(review): this drives the connection from
        # a foreign loop and relies on the owning loop being idle meanwhile.
        errors = []

        def _run() -> None:
            try:
                asyncio.run(self._websocket.send(message))
            except BaseException as exc:  # re-raised in the caller below
                errors.append(exc)

        worker = threading.Thread(target=_run)
        worker.start()
        worker.join()
        if errors:
            raise errors[0]

    async def input(self, prompt: str = "", *, password: bool = False) -> str:
        """Optionally send *prompt* to the client, then wait for one text message.

        NOTE: this is a coroutine and must be awaited; a caller that invokes
        ``input()`` synchronously receives a coroutine object, not a string.

        Args:
            prompt (str, optional): Text sent to the client before reading.
            Defaults to "".
            password (bool, optional): Unused; accepted for signature
            compatibility. Defaults to False.

        Returns:
            str: The text message received from the client.
        """
        if prompt:
            # WebSocket.send() takes an ASGI message dict; send_text() wraps a
            # plain string for us.
            await self._websocket.send_text(prompt)

        msg = await self._websocket.receive_text()

        return msg.decode("utf-8") if isinstance(msg, bytes) else msg

Define your on_connect handler

async def chat_on_connect(
    iostream: CustomIOWebsockets,
) -> None:
    """Read one JSON request from the client and run the agent chat for it.

    The client must send a JSON object with a ``"query"`` string. That string
    becomes the opening message from ``user_proxy`` to ``agent``.

    Args:
        iostream: The stream wrapping the client's websocket.

    Raises:
        json.JSONDecodeError: If the client message is not valid JSON.
        ValueError: If the message is not a JSON object with a non-empty
            string ``"query"``.
    """
    logger.debug(
        " - on_connect(): Connected to client using CustomIOWebsockets %s", iostream
    )

    logger.debug(" - on_connect(): Receiving message from client.")

    msg = json.loads(await iostream.input())
    query = msg.get("query") if isinstance(msg, dict) else None
    if not isinstance(query, str) or not query.strip():
        raise ValueError('Expected a JSON object with a non-empty "query" string.')

    # NOTE(review): initiate_chat is synchronous, so it blocks this worker's
    # event loop (and every other connection served by it) until the chat ends.
    # Running it in a thread (e.g. asyncio.to_thread) would avoid that, but only
    # if the IOStream's print() is safe to call off the event-loop thread.
    user_proxy.initiate_chat(agent, message=query)
  

Integrating with FastAPI

from fastapi.websockets import WebSocketState
from fastapi import WebSocket
from starlette.websockets import WebSocketDisconnect

# The FastAPI application that hosts the /chat websocket route.
app = FastAPI(
    version="1.0.1",
    title="WebSockets with Autogen",
    description="A better websocket with Autogen",
)

@app.websocket("/chat")
async def websocket_endpoint_v2(
    websocket: WebSocket,
):
    """Accept a client on ``/chat`` and run one agent chat over the connection.

    A client disconnect is logged at INFO. Any other error is logged with its
    traceback and reported to the client with a generic message. The socket is
    closed on the way out if it is still connected.
    """
    await websocket.accept()
    # NOTE(review): the Origin header is not validated. If pages on other sites
    # must not open this socket, check websocket.headers.get("origin") against
    # an allow-list here.
    try:
        await CustomIOWebsockets.handler(websocket, chat_on_connect)
    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        # logger.exception already attaches the traceback.
        logger.exception("An error occurred: %s", e)
        if websocket.client_state == WebSocketState.CONNECTED:
            try:
                await websocket.send_text("An internal server error occurred. Closing.")
            except (RuntimeError, WebSocketDisconnect):
                # Best effort: the client may already be gone.
                pass
    finally:
        if websocket.client_state == WebSocketState.CONNECTED:
            try:
                await websocket.close()
            except (RuntimeError, WebSocketDisconnect):
                # Best effort: the connection may have closed concurrently.
                pass

Start your ASGI server (this assumes the code above is saved as server.py):

uvicorn server:app --workers 3 --ws auto

Leave a Reply