Persisting logs in DB for each task in Celery (+ FastAPI)

At FereAI.xyz, we use Celery at scale. Whenever the trading agent performs a trade, a manual sell is executed, or emails are sent for scheduled jobs, the work runs through Celery.

Celery is great, but I have always missed the Airflow feature that lets you see the logs of each individual task run and use them to diagnose or debug issues. So, I decided to build something similar.

Requirements

I wanted a solution where

  • All logs from celery jobs are stored in DB along with their task id, status & a few other params
  • Logging works out of the box for Celery, FastAPI & standalone usage (Jupyter notebooks)

Here’s what I did

A database table for storing logs

from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

from sqlalchemy import Column, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from datetime import datetime
from sqlalchemy.sql.expression import func
from sqlalchemy import DateTime
from .base import Base


class BatchedLog(Base):
  """One row of persisted logs for a single Celery task run.

  Populated by the task_postrun / task_failure signal handlers; ``logs``
  holds the JSON-encoded list of record dicts built by TaskLogHandler.
  NOTE(review): the model mixes legacy ``Column`` with 2.0-style
  ``mapped_column`` — consider standardizing on one declaration style.
  """

  __tablename__ = "batched_log"
  id = Column(Integer, primary_key=True, autoincrement=True)
  # Celery task UUID (string form); indexed for per-task lookups.
  task_id = Column(String(255), nullable=True, index=True)
  # Registered task name; indexed so logs can be filtered by task type.
  task_name = Column(String(255), nullable=True, index=True)
  # Lower-cased terminal state, e.g. "success" / "failure" (see signal handlers).
  status = Column(String(255), nullable=True, index=True)
  logs = Column(Text)  # Store logs as JSON
  # Row creation time; func.now() emits the SQL NOW() at INSERT time.
  created_at: Mapped[datetime] = mapped_column(
    DateTime(timezone=True),
    nullable=False,
    default=func.now(),
    index=True,
  )

Logger config

import contextvars
import logging
import os

from celery.app.log import TaskFormatter

# Create and configure the logger
logger = logging.getLogger("trader")
logger.setLevel(logging.DEBUG)

# Remove existing handlers to avoid duplication (e.g. on module re-import)
if logger.hasHandlers():
  logger.handlers.clear()

# Stream handler for outputting to stderr/stdout
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.DEBUG)

# Formatter for the log messages; TaskFormatter fills task_id / task_name
# when running inside a Celery task and "???" otherwise.
formatter = TaskFormatter(
  "%(asctime)s - %(task_id)s - %(task_name)s - %(name)s - %(threadName)s- %(levelname)s - %(message)s"
)
# BUG FIX: the formatter was created but never attached to the handler, so
# stream output silently fell back to the default "%(message)s" format.
stream_handler.setFormatter(formatter)

logger.addHandler(stream_handler)

# Prevent propagation to the root logger
logger.propagate = False

# Per-context buffer of structured log records for the current task.
# NOTE(review): the mutable [] default is shared by every context that never
# calls task_logs.set(); task_prerun resets it per task, but pure standalone
# use appends into one shared list — verify this is acceptable.
task_logs = contextvars.ContextVar("task_logs", default=[])

# Define a context variable to store the active logger (task- or
# request-scoped); None means "no special context — use the module logger".
logger_context = contextvars.ContextVar("logger_context", default=None)


def get_logger():
  """Retrieve the current logger from the context variable.

  Returns the context-bound logger when a Celery task or FastAPI request
  has set one; otherwise falls back to the module-level "trader" logger,
  renaming it to reflect the runtime environment.
  """
  lc = logger_context.get()
  if lc is None:
    import sys

    # BUG FIX: ``"ipykernel" in os.environ`` tests environment-variable
    # *names*, which never include "ipykernel". Checking the loaded modules
    # is the reliable way to detect a Jupyter kernel.
    logger.name = "Jupyter" if "ipykernel" in sys.modules else "Standalone"
    return logger
  return lc

Celery app

import json
import logging
import traceback
from uuid import UUID

from celery.app import Celery
from celery.app.log import TaskFormatter
from celery.schedules import crontab
from celery.signals import after_setup_task_logger
from celery.utils.log import get_task_logger
from celery.signals import task_prerun, task_postrun, task_failure
from celery.result import AsyncResult

from .logger_config import task_logs, logger_context, get_logger

# Celery application using Redis for both the broker and the result backend.
# NOTE(review): get_redis_url() is defined elsewhere in the project.
app = Celery("trader", broker=get_redis_url(), backend=get_redis_url())
# Let task loggers emit everything down to DEBUG.
app.conf.task_logging_level = logging.DEBUG

# Signal to initialize logs
@task_prerun.connect
def initialize_logs(task_id=None, task=None, args=None, kwargs=None, **extras):
  """Bind a task-scoped logger and reset the per-task log buffer.

  Runs before every Celery task. BUG FIX: the original referenced an
  undefined name ``celery_task_logger``; the logger is now created via
  ``get_task_logger`` (imported above but previously unused).
  """
  task_logger = get_task_logger("trader")
  logger_context.set(task_logger)
  # Fresh buffer so log entries never leak between tasks on the same worker.
  task_logs.set([])
  task_logger.info(f"Logger set for Celery task: {task_id}")


# Signal to persist logs on task completion
@task_postrun.connect
def persist_logs_on_completion(
  task_id=None, task=None, args=None, kwargs=None, retval=None, state=None, **extras
):
  """Write the buffered task logs to the database after the task finishes.

  ``state`` is delivered directly by the task_postrun signal (previously it
  was absorbed into **extras), which avoids a round-trip to the result
  backend via AsyncResult. The backend is consulted only as a fallback.
  """
  status = state or AsyncResult(task_id).status or "UNKNOWN"

  task_name = task.name if task else "unknown"

  logs = task_logs.get()
  print(f"Persisting logs for task_id={task_id}")

  # Save logs to the database; merge() upserts in case task_failure already
  # wrote a row for this task.
  with get_trade_db_session() as session:
    session.merge(
      BatchedLog(
        task_id=task_id,
        status=status.lower(),
        task_name=task_name,
        logs=json.dumps(logs),
      )
    )
    session.commit()
  # Detach the task logger so later standalone/FastAPI calls get a clean context.
  logger_context.set(None)

# Signal to handle task failure
@task_failure.connect
def handle_failure(
  task_id=None,
  task=None,
  args=None,
  kwargs=None,
  exc=None,
  traceback=None,
  **extras,
):
  """Log the exception and persist the buffered logs when a task fails.

  The ``traceback`` parameter name is dictated by Celery's task_failure
  signal, so it cannot be renamed — but it shadows the stdlib module.
  """
  # BUG FIX: the parameter ``traceback`` shadowed the imported module, so
  # ``traceback.format_exception`` was being looked up on a traceback object
  # and raised AttributeError. Re-import the module under another name.
  import traceback as tb_module

  logger = get_logger()
  # Guard against a backend that has no state yet — .lower() on None crashes.
  status = AsyncResult(task_id).status or "FAILURE"

  task_name = task.name if task else "unknown"

  logger.error(f"Task failed: {exc}")
  logger.error(
    "".join(tb_module.format_exception(type(exc), exc, exc.__traceback__))
  )

  logs = task_logs.get()
  print(f"Task failed. Persisting logs for task_id={task_id}")

  # Save logs to the database
  with get_trade_db_session() as session:
    session.merge(
      BatchedLog(
        task_id=task_id,
        status=status.lower(),
        task_name=task_name,
        logs=json.dumps(logs),
      )
    )
    session.commit()

class TaskLogHandler(logging.Handler):
  """Logging handler that buffers each record into the task-local
  ``task_logs`` ContextVar as a structured dict, ready to be JSON-dumped
  by the task_postrun / task_failure persistence handlers."""

  def emit(self, record):
    entry = {
      "level": record.levelname,
      "message": record.getMessage(),
      "time": record.created,
      "filename": record.filename,
      "lineno": record.lineno,
      "funcName": record.funcName,
      "thread": record.threadName,
    }
    buffered = task_logs.get()
    buffered.append(entry)
    task_logs.set(buffered)


@after_setup_task_logger.connect
def setup_task_logger(logger, *args, **kwargs):
  """Wire Celery's task logger into the DB log buffer.

  Adds the TaskLogHandler and applies a task-aware formatter to every
  handler on the logger (including the one just added), so each line
  carries the task id and task name.
  """
  logger.addHandler(TaskLogHandler())
  fmt = TaskFormatter(
    "%(asctime)s - %(task_id)s - %(task_name)s - %(name)s - %(levelname)s - %(message)s"
  )
  for existing_handler in logger.handlers:
    existing_handler.setFormatter(fmt)

@app.task
def test_task():
  # Demo task showing the logging flow end to end.
  # NOTE(review): ``celery_task_logger`` and ``agents_started`` are not
  # defined in this snippet — presumably created elsewhere in the module;
  # verify before copying this example.
  celery_task_logger.info("Starting main task")
  try:
    celery_task_logger.info(
      f"Started {len(agents_started)} agents with UUIDs: {agents_started}"
    )
    return agents_started
  except Exception as e:
    # Log the failure locally too; task_failure will persist the buffer.
    celery_task_logger.error(f"Task failed: {e}")
    celery_task_logger.error(traceback.format_exc())
    raise

Any other custom classes or functions should also obtain their logger via get_logger

from .logger_config import get_logger

class MyClass:
  """Example consumer of the context-aware logger.

  The logger is resolved lazily through a property instead of being captured
  in ``__init__``: ``get_logger()`` reads a ContextVar, so binding it at
  construction time would freeze whichever context (task / request /
  standalone) happened to be active when the instance was created.
  ``self.logger`` still works exactly as before for callers.
  """

  @property
  def logger(self):
    # Resolve against the *current* context on every access.
    return get_logger()

  def foo(self):
    self.logger.info("Inside foo")

fastapi app

# FastAPI application; the "..." is the article's placeholder for the real
# constructor kwargs (title, middleware, routers, etc.).
app = FastAPI(
...
)

@app.middleware("http")
async def set_fastapi_logger(request: Request, call_next):
  # Assign a request-specific logger to the logger context
  if not hasattr(request.state, "logger"):
    request.state.logger = get_logger()
  logger_context.set(request.state.logger)
  response = await call_next(request)
  # Clear the logger context after request handling
  logger_context.set(None)
  return response

Leave a Reply