API Testing in FastAPI with Dependency Injection and Mock Objects

Mocking and dependency injection are tools that can help you test API endpoints without sending emails or connecting to a live database!

With these tools, you can automatically confirm that your app responds to requests with different headers and cookies.

In this post, I will show how to create mocks for a SQLAlchemy database and a custom email interface. I’ll also demonstrate how to simulate API calls while using the mock objects with pytest.

The Example API Endpoint: User Registration

In this endpoint, the database and email connections are inline. I’ll show you how to pull the connections out of the method and swap them with mock objects while testing.

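from fastapi import HTTPException, status

# `app` is the FastAPI instance; db, models, auth, account_lib,
# email_utils and config are the project’s own modules.

@app.post("/auth/signup", status_code=201)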
async def signup(req: models.Signup):
    with db.engine.connect() as conn:
        with conn.begin():
            user = db.get_user_with_email(conn, req.user.email)
            if user is not None:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=...
                )
            if not auth.check_password_strength(req.user.password):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=...
                )
            user_id = account_lib.create_user(conn, req.user)
            acct_pk, acct_pid = account_lib.create_site_account(
                conn=conn,
                user_id=user_id,
                site_account=req.site_account
            )

    with email_utils.MailServerConnection() as mail_conn:
        msg = email_utils.HtmlEmailMessage()
        msg['To'] = req.user.email
        msg['From'] = config.SUPPORT_EMAIL_ADDRESS
        msg['Subject'] = ...
        msg.set_html_body(...)
        mail_conn.send_message(msg)

    return auth.create_token_response(req.user.email)

What is Mocking?

When you write automated tests, you don’t want to set up an email server or a database just to test your code. That’s where mocking comes in.

A “mock” object implements the same interface as another object, but with the real functionality stripped out. It only pretends to do the job it’s supposed to do.

Mock databases pretend to store data in a database. Mock email clients pretend to send emails. Mock file interfaces pretend to write to the hard drive. Mock SDKs pretend to send requests to another system (moto for AWS is a great example of a mock SDK).
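Python’s standard library can also generate a generic mock for you with unittest.mock. Here is a minimal sketch; MailServer and welcome_user are hypothetical stand-ins, not part of this project:

```python
from unittest import mock


class MailServer:
    """Hypothetical real client; actually sending would need a live server."""

    def send_message(self, msg):
        raise RuntimeError("no mail server available in tests")


def welcome_user(mail_server, email):
    """Hypothetical app code that only relies on the send_message() interface."""
    mail_server.send_message(f"Welcome, {email}!")


# spec=MailServer makes the mock reject attributes the real class lacks
fake_server = mock.Mock(spec=MailServer)
welcome_user(fake_server, "spam@example.com")

# The mock recorded the call instead of sending anything
fake_server.send_message.assert_called_once_with("Welcome, spam@example.com!")
print(fake_server.send_message.call_count)  # 1
```

Generated mocks like this are quick to set up. Hand-written mocks, like the ones in the examples below, are handy when the fake needs to behave more like the real thing.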


Some Mocking Examples:

An SQLAlchemy Mock Database Connection

If you use SQLAlchemy, you’re probably creating the database engine like this:

# db module
import sqlalchemy as sa

import config  # the project's settings module

engine = sa.create_engine(config.SQLALCHEMY_DB_URL, echo=True)

And to get a connection to the database, you might call something like this:

with db.engine.connect() as conn:
    with conn.begin():
        user = db.get_user_with_email(conn, create_user_req.email)

SQLAlchemy makes it easy to create a mock database: point a new engine at an in-memory SQLite database instead of the real one.

With pytest, we can create a test fixture to create the mock database:

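import pytest
import sqlalchemy as sa

import db  # the project's db module shown above


@pytest.fixture()
def mock_db_conn():
    # A fresh, empty in-memory SQLite database for each test
    mock_engine = sa.create_engine('sqlite:///:memory:', echo=True)
    db.metadata.create_all(mock_engine)  # build the schema
    with mock_engine.connect() as conn:
        with conn.begin():
            yield conn  # the test runs here
    db.metadata.drop_all(mock_engine)  # tear the schema down


# example usage
def test_get_user(mock_db_conn):
    user = db.get_user_with_email(mock_db_conn, 'spam@example.com')
    assert user is None  # nothing has been inserted yet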

The fixture creates the db engine in memory with the sqlite:///:memory: connection string.

It then uses the database metadata to create the schema in the new engine before the test.

Finally, it opens a connection, begins a transaction, and yields the connection to the test. After the test finishes, the transaction is committed, the connection is closed, and the tables are dropped.

Tearing down and rebuilding the database this way ensures that each test runs in a clean environment. We don’t want tests seeing data that was created by other tests.
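The in-memory database adds a second layer of isolation: every SQLite :memory: connection opens its own private, empty database, which disappears when that connection closes. Here is a quick check with Python’s built-in sqlite3 module (the driver SQLAlchemy uses for sqlite:// URLs by default); the users table is just an illustration:

```python
import sqlite3


def table_names(conn):
    """List the tables visible through a connection."""
    rows = conn.execute("SELECT name FROM sqlite_master WHERE type = 'table'")
    return [name for (name,) in rows]


def demo():
    first = sqlite3.connect(":memory:")
    second = sqlite3.connect(":memory:")  # a separate, brand-new database
    try:
        first.execute("CREATE TABLE users (email TEXT PRIMARY KEY)")
        first.execute("INSERT INTO users VALUES ('spam@example.com')")
        return table_names(first), table_names(second)
    finally:
        first.close()
        second.close()


print(demo())  # (['users'], [])
```

This is also why the fixture can call create_all() and then connect() on the same engine: for :memory: URLs, SQLAlchemy’s default pool (SingletonThreadPool) hands back the same underlying connection within a thread.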

A Mock Mail Server Connection

We are lucky with this mock database, because SQLAlchemy makes it easy, but what about when we have to build one ourselves?

My mail server connection doesn’t provide a mock object. So, I have to write one myself.

Here is how it is used again:

with email_utils.MailServerConnection() as mail_conn:
    msg = email_utils.HtmlEmailMessage()
    msg['To'] = req.user.email
    msg['From'] = config.SUPPORT_EMAIL_ADDRESS
    msg['Subject'] = ...
    msg.set_html_body(...)
    mail_conn.send_message(msg)

The MailServerConnection is a context manager that opens and closes a connection to the mail server. Its __enter__ method returns an object that implements a send_message() method.

So, we just make another class that implements the same method.

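class MockMailServer:
    """Pretends to be the object that MailServerConnection.__enter__ returns."""

    def __init__(self):
        # Per-instance list, so tests don't see each other's messages
        self.sent_messages = []

    def send_message(self, msg):
        self.sent_messages.append(msg)  # record the message instead of sending it


@pytest.fixture()
def mock_email_conn():
    return MockMailServer()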
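If the code under test still opens the connection inline with a with statement, you can mock the context manager itself. Here is a minimal sketch; MockMailServerConnection is a hypothetical name, and the standard library’s EmailMessage stands in for the post’s HtmlEmailMessage:

```python
from email.message import EmailMessage


class MockMailServerConnection:
    """Hypothetical stand-in for email_utils.MailServerConnection."""

    def __init__(self):
        self.sent_messages = []

    def __enter__(self):
        return self  # the object that receives send_message() calls

    def __exit__(self, exc_type, exc, tb):
        return False  # no connection to close; don't swallow exceptions

    def send_message(self, msg):
        self.sent_messages.append(msg)


mock_conn = MockMailServerConnection()
with mock_conn as mail_conn:
    msg = EmailMessage()
    msg['To'] = 'spam@example.com'
    msg['From'] = 'support@example.com'
    msg['Subject'] = 'Welcome!'
    msg.set_content('Thanks for signing up.')
    mail_conn.send_message(msg)

# A test can now assert on what "would have been" sent
print(len(mock_conn.sent_messages))        # 1
print(mock_conn.sent_messages[0]['To'])    # spam@example.com
```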

Dependency Injection: How to Use Mock Objects

Now we have the mock objects, but we need FastAPI to use the mocks instead of live connections during the API tests.

FastAPI provides a dependency injection system to help us accomplish that.

How to Set Up Dependency Injection

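To use dependency injection in FastAPI, you wrap a callable (a function, or an instance of a class with a __call__ method) in Depends() and attach it to one of the endpoint’s parameters, either as the parameter’s default value or inside an Annotated type hint. For every request, FastAPI calls the dependency and passes the result in as that argument. If the dependency uses yield, FastAPI injects the yielded value and runs the code after the yield once the request has been handled, which is how the connections below get closed.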

Let’s mark the database and email connections as dependencies for our API endpoint from before.

#!./deps.py
from typing import Annotated

import sqlalchemy
from fastapi import Depends

import db           # project modules
import email_utils


def db_connection():
    with db.engine.connect() as conn:
        with conn.begin():
            yield conn


# `Annotated` shows the type of an object to the IDE
DBConnectionDep = Annotated[sqlalchemy.engine.Connection, Depends(db_connection)]


def mail_connection():
    with email_utils.MailServerConnection() as conn:
        yield conn

# It isn't necessary to annotate the dependencies, though.

#!./app.py
import deps
...

@app.post("/auth/signup", status_code=201)
async def signup(
    req: models.Signup,
    db_conn: deps.DBConnectionDep,           # annotated
    mail_conn=Depends(deps.mail_connection)  # not annotated
):
    user = db.get_user_with_email(db_conn, req.user.email)
    if user is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=...
        )

    if not auth.check_password_strength(req.user.password):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=...
        )

    user_id = account_lib.create_user(db_conn, req.user)
    acct_pk, acct_pid = account_lib.create_site_account(
        conn=db_conn,
        user_id=user_id,
        site_account=req.site_account
    )

    msg = email_utils.HtmlEmailMessage()
    msg['To'] = req.user.email
    msg['From'] = config.SUPPORT_EMAIL_ADDRESS
    msg['Subject'] = ...
    msg.set_html_body(...)
    mail_conn.send_message(msg)

    return auth.create_token_response(req.user.email)

These are two examples of adding a dependency:

  • For the database, the dependency is annotated with the Connection class type and used as a type hint (note the colon). This is more complex but will provide better type hinting in an IDE.
  • For the mail server, the dependency is wrapped with Depends and passed as a default argument (note the equals sign). This is simpler but gives the IDE less detail.

Both approaches are valid.
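Both dependencies above are generators. As a simplified model (plain Python, not FastAPI’s own code), wrapping a generator with contextlib.contextmanager produces the same “set up, inject, clean up” order. FastAPI itself wraps sync generator dependencies with contextmanager too, adding thread and async handling on top. fake_db_connection, signup and handle_request here are hypothetical names:

```python
from contextlib import contextmanager

events = []


def fake_db_connection():
    """Hypothetical dependency shaped like deps.db_connection."""
    events.append("open connection")
    yield "fake-connection"            # value injected into the endpoint
    events.append("close connection")  # runs after the endpoint returns


def signup(db_conn):
    """Hypothetical endpoint that receives the injected value."""
    events.append(f"endpoint got {db_conn}")
    return {"status": "created"}


def handle_request():
    # Run the generator up to its yield, call the endpoint, then resume it
    with contextmanager(fake_db_connection)() as conn:
        return signup(conn)


print(handle_request())  # {'status': 'created'}
print(events)  # ['open connection', 'endpoint got fake-connection', 'close connection']
```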

The FastAPI docs only outline how to use dependency injection, but they don’t explain how it works. That’s not enough for me. So, I dug through the source code.

How Does Dependency Injection Work?

In fastapi.dependencies.utils, there is a function called solve_dependencies. I stripped out everything but the essentials and pasted it below:

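async def solve_dependencies(
    *,
    ...
    dependant: Dependant,
    # The provider is the "app" or FastAPI object
    dependency_overrides_provider=None,
    dependency_cache=None,
    ...
):
    ...
    dependency_cache = dependency_cache or {}
    for sub_dependant in dependant.dependencies:
        ...
        use_sub_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            original_call = sub_dependant.call
            # Overrides are found with a plain dict lookup on the original callable
            call = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            ).get(original_call, original_call)
            use_sub_dependant = get_dependant(..., call=call, ...)
        solved_result = await solve_dependencies(
            ...
            dependant=use_sub_dependant,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
            ...
        )
        ...
        dependency_cache.update(...)
        ...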

This function recursively walks through everything marked as a dependency and resolves it.

When a client sends a request, FastAPI calls this function on the dependency tree it built from the endpoint’s signature when the route was registered. Each resolved dependency is cached for the rest of that request, so a dependency used in several places runs only once. The cache starts empty for every request, and you can bypass it with Depends(..., use_cache=False). Note that the resolver checks for a “dependency overrides provider” before calling each dependency.
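To make the recursion, the per-request cache, and the override lookup concrete, here is a toy resolver in plain Python. It is a simplified model, not FastAPI’s implementation: the Depends class below is a hypothetical stand-in for fastapi.Depends, and solve() only handles dependencies declared as default values.

```python
import inspect
from typing import Any, Callable


class Depends:
    """Hypothetical marker, loosely modeled on fastapi.Depends."""

    def __init__(self, dependency: Callable[..., Any], use_cache: bool = True):
        self.dependency = dependency
        self.use_cache = use_cache


def solve(call: Callable[..., Any], overrides: dict, cache: dict) -> Any:
    """Resolve every Depends() default of `call` recursively, then call it."""
    kwargs = {}
    for name, param in inspect.signature(call).parameters.items():
        marker = param.default
        if not isinstance(marker, Depends):
            continue
        # Plain dict lookup: the key must hash and compare equal to the original
        sub_call = overrides.get(marker.dependency, marker.dependency)
        if marker.use_cache and sub_call in cache:
            kwargs[name] = cache[sub_call]  # already resolved in this "request"
            continue
        value = solve(sub_call, overrides, cache)
        if marker.use_cache:
            cache[sub_call] = value
        kwargs[name] = value
    return call(**kwargs)


calls = []


def get_db():
    calls.append("get_db")
    return "real-db"


def get_repo(db=Depends(get_db)):
    return f"repo({db})"


def endpoint(db=Depends(get_db), repo=Depends(get_repo)):
    return db, repo


# One "request" with no overrides: get_db runs once, thanks to the cache
print(solve(endpoint, overrides={}, cache={}))  # ('real-db', 'repo(real-db)')
print(calls)  # ['get_db']

# A fresh cache plus an override: get_db is never called
print(solve(endpoint, overrides={get_db: lambda: "mock-db"}, cache={}))
# ('mock-db', 'repo(mock-db)')
print(calls)  # ['get_db']
```

Because the lookup is overrides.get(original, original), a key that doesn’t match simply falls back to the real dependency, with no error.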

Replacing Dependencies with Mocks

There are at least two ways to test the API methods at this point. I’ll start with the simplest.

Calling API Methods Directly and Passing in the Mocks

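import json

import pytest

from server import app as app_lib
from server import models  # assumed location of the request models
from test.utils import mock_db_conn, mock_email_conn  # these are test fixtures

# plugin required to test async functions (pytest-asyncio)
pytest_plugins = ('pytest_asyncio',)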

@pytest.mark.asyncio
async def test_app_directly(mock_db_conn, mock_email_conn):
    response = await app_lib.signup(
        req=models.Signup(
            user=models.CreateUser(
                email=...,
                password=...,
                ...
            ),
            site_account=models.CreateSiteAccount(
                ...
            ),
            posts=[],
        ),
        db_conn=mock_db_conn,
        mail_conn=mock_email_conn
    )
    assert response.status_code == 200
    response_json = json.loads(response.body)
    assert 'access_token' in response_json
    assert 'refresh_token' in response_json

This is the most straightforward way to test the API endpoint. Unfortunately, it requires you to know which model classes to pass to the method, and because you build those models yourself, it skips FastAPI’s routing and request validation. It also isn’t obvious how you might test different headers or cookies.

Simulating API Calls

If you stop reading right here, you’ll already be off to a good start with testing your project. This next strategy is powerful, but it has some gotchas to look out for. I will show you how to work around them.

FastAPI provides a TestClient (re-exported from Starlette and built on HTTPX) that simulates HTTP requests to the FastAPI instance without starting a real server. This lets you change headers and cookies, and it adds isolation between the tests and the code. To use it, wrap the app instance in TestClient and call get or post on it.

from fastapi.testclient import TestClient

from server import app as app_lib

client = TestClient(app_lib.app)

def test_robots():
    response = client.get('/robots.txt')
    assert response.status_code == 200
    assert response.text.startswith('User-agent')

Remember the “dependency overrides provider” that the resolver checks? That provider is the app itself, and its dependency_overrides dictionary maps each original dependency to a replacement, so you can swap dependencies out from the test suite. Below is an example test module using TestClient and dependency overrides, with some gotchas mentioned in the comments. I will go over each of them in detail.

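#!./deps.py
import smtplib
from dataclasses import dataclass
from typing import Annotated

import sqlalchemy
from fastapi import Depends

from server import db, email_utils  # project modules

# Gotcha 1: the key you put in dependency_overrides must compare equal to
#  the callable passed to Depends, or the override is silently skipped.
#  Defining __hash__ and __eq__ on a callable class guarantees that,
#  even when the test creates a new instance.

# In this example, I've written __hash__ and __eq__ manually. They must be
# defined on the class: assigning them as attributes of a plain function
# does not change how hash() or a dict lookup treats it.
class DBDependency:
    def __call__(self):
        with db.engine.connect() as conn:
            with conn.begin():
                yield conn

    def __hash__(self):
        return hash('server.deps.DBDependency')

    def __eq__(self, other):
        return isinstance(other, DBDependency)


db_dependency = DBDependency()
DBConnectionDep = Annotated[sqlalchemy.engine.Connection, Depends(db_dependency)]


# In this example, I am leveraging @dataclass to define those methods for me.
# I like this option, because there is less room for error.
@dataclass(frozen=True, eq=True)
class MailConnectionCallable:
    def __call__(self):  # note: *args and **kwargs are not used.
        with email_utils.MailServerConnection() as conn:
            yield conn


MailConnectionDep = Annotated[smtplib.SMTP_SSL, Depends(MailConnectionCallable())]

#!test/test_app.py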

from server import app as app_lib 
from server import deps

from fastapi.testclient import TestClient
from test.utils import mock_db_conn, mock_email_conn  # these are test fixtures

client = TestClient(app_lib.app)
# Gotcha 2: unwrap test fixtures in order to call them directly
app_lib.app.dependency_overrides[deps.db_dependency] = mock_db_conn.__wrapped__
app_lib.app.dependency_overrides[deps.MailConnectionCallable()] = mock_email_conn.__wrapped__


def test_signup():
    response = client.post('/auth/signup', json={
        'user': {
            'email': ...,
            'password': ...,
            ...
        },
        'site_account': {
            ...
        },
        'posts': [],
    })
    assert response.status_code == 200
    response_json = response.json()
    assert 'access_token' in response_json
    assert 'refresh_token' in response_json

#!test/utils.py
import pytest
import sqlalchemy as sa

from server import db  # project module

# (the mock_email_conn fixture from earlier also lives in this module)


@pytest.fixture()
def mock_db_conn():
    # Gotcha 3: the thread failsafe (check_same_thread) for SQLite
    #   is disabled. This allows the connection to be created in one
    #   thread and used in another, which happens while testing
    #   with TestClient.
    mock_engine = sa.create_engine(
        'sqlite:///:memory:',
        echo=True,
        connect_args={'check_same_thread': False}
    )
    db.metadata.create_all(mock_engine)
    with mock_engine.connect() as conn:
        with conn.begin():
            yield conn
    db.metadata.drop_all(mock_engine)

Gotcha #1: Define __hash__ and __eq__ on all Dependencies

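FastAPI finds overrides with a plain dictionary lookup (dependency_overrides.get(original_call, original_call)), so the key you register must hash and compare equal to the exact callable passed to Depends. If it doesn’t, FastAPI quietly falls back to the real dependency. By default, a function hashes and compares by identity, so only that very same function object will match.

To see identity-based hashing in action, you can create two Python files like so: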

#!./deps.py

def say_hello():
    print("Hello World!")

#!./test.py

from multiprocessing import Process

#import deps  # toggle comment and run to compare
def fn():
    import deps
    print('process_hi', hash(deps.say_hello))

if __name__ == '__main__':
    p = Process(target=fn, args=())
    p.start()
    p.join()
    import deps
    print('   base_hi', hash(deps.say_hello))

If you run python test.py, you get:

process_hi 8747143503162
   base_hi 8747143523702

Notice that the function hashes don’t match, even though this is the same function.

I thought this might be related to PYTHONHASHSEED (there’s a great writeup about it on Stack Overflow if you’re interested). However, PYTHONHASHSEED doesn’t affect functions: in CPython, the default __hash__ for a function is derived from the function object’s address in memory.

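Because the child process imported deps on its own, its copy of the function lives at a different address from the one the main process imports later. Note that TestClient itself doesn’t start a new process: it runs the app in the same process, on an event loop in a background thread. Within one process, the same kind of mismatch shows up when the dependency module is imported under two different names (for example, import deps in one file and from server import deps in another), because each import creates a separate function object.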

There is a workaround: if you uncomment the import deps line in test.py above, you will get this:

process_hi 8744865430722
   base_hi 8744865430722

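As you can see, the hashes match when the module is imported before the new process starts. That only works with the fork start method, where the child inherits a copy of the parent’s memory; with spawn (the default on macOS and Windows), the child re-imports its modules from scratch.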

Don’t tell your coworkers to import modules in a special way, though. 🤭

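Instead, make sure the override key reliably matches. Import dependencies through the same module path everywhere, and give dependencies that a test re-creates (such as MailConnectionCallable()) __hash__ and __eq__ methods so a new instance matches the one passed to Depends. Those methods must be defined on a class: assigning them as attributes of a plain function does not change how hash() or a dictionary treats it. You can write the methods yourself or use @dataclass to do it for you (as seen in the Gotcha #1 example above).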
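To check these points in plain Python, here is a small standard-library sketch. It assigns a __hash__ to a function (which hash() ignores), shows two dataclass instances finding each other in a dict, and loads one hypothetical deps.py file under two module names to show that each load produces a distinct function object:

```python
import importlib.util
import pathlib
import tempfile
from dataclasses import dataclass


def db_dependency():
    yield "real connection"


# Dunder methods set on a function instance are ignored by hash() and dicts
db_dependency.__hash__ = lambda: 42
print(hash(db_dependency) == 42)  # False


@dataclass(frozen=True)
class MailConnectionCallable:
    def __call__(self):
        yield "real mail connection"


def mock_mail():
    yield "mock mail connection"


# Separate instances are equal and hash the same, so a fresh instance
# created in a test finds the override registered under another instance
overrides = {MailConnectionCallable(): mock_mail}
print(overrides.get(MailConnectionCallable()) is mock_mail)  # True


def load_module(path, name):
    """Execute a source file as a new module object named `name`."""
    spec = importlib.util.spec_from_file_location(name, str(path))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp) / "deps.py"
    path.write_text("def db_dependency():\n    yield 'conn'\n")
    as_deps = load_module(path, "deps")
    as_server_deps = load_module(path, "server.deps")

# Same source, two module objects, two different function objects
print(as_deps.db_dependency == as_server_deps.db_dependency)  # False
```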

Gotcha #2: Unwrap Test Fixtures in Order to Call Them Directly

If you use @pytest.fixture() on a function that you would like to reuse as a source of mock objects, make sure to pass the __wrapped__ attribute from the function to the dependency_overrides table. Otherwise, you will get an error saying that you cannot call pytest fixtures directly. __wrapped__ just stores the function as it was before it was decorated as a test fixture.
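The __wrapped__ attribute is a general Python convention: functools.wraps sets it on a wrapper so the original function stays reachable. Here is a stdlib sketch of the idea, where no_direct_calls is a hypothetical decorator that imitates a fixture’s refusal to be called directly (pytest’s real fixture machinery is more involved):

```python
import functools


def no_direct_calls(func):
    """Hypothetical decorator that, like a fixture, refuses direct calls."""
    @functools.wraps(func)  # copies metadata and sets wrapper.__wrapped__ = func
    def wrapper(*args, **kwargs):
        raise RuntimeError(f"{func.__name__} cannot be called directly")
    return wrapper


@no_direct_calls
def mock_db_conn():
    yield "in-memory connection"


try:
    mock_db_conn()
except RuntimeError as err:
    print(err)  # mock_db_conn cannot be called directly

# The undecorated generator function is still available
original = mock_db_conn.__wrapped__
print(next(original()))  # in-memory connection
```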

Gotcha #3: Disable check_same_thread for SQLite While Testing

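By default, Python’s sqlite3 module won’t let a connection be used from any thread except the one that created it, because unserialized writes from several threads can corrupt data. While testing, though, the mock connection crosses threads: FastAPI runs sync generator dependencies (like the mock_db_conn override) in a worker thread, while the async endpoint runs on the event loop. Since each test sends one request at a time, it’s reasonably safe to disable the thread-checking failsafe.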

check_same_thread (bool) – If True (default), ProgrammingError will be raised if the database connection is used by a thread other than the one that created it. If False, the connection may be accessed in multiple threads; write operations may need to be serialized by the user to avoid data corruption. See threadsafety for more information.

Python sqlite3 module documentation
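You can watch the failsafe trip with the standard library alone. This sketch creates connections in the main thread and queries them from a second thread; query_from_other_thread is a hypothetical helper name:

```python
import sqlite3
import threading


def query_from_other_thread(conn):
    """Run a query on conn from a new thread and report what happened."""
    outcome = {}

    def worker():
        try:
            outcome["value"] = conn.execute("SELECT 1").fetchone()[0]
        except sqlite3.ProgrammingError:
            outcome["error"] = "ProgrammingError"

    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()
    return outcome


strict = sqlite3.connect(":memory:")  # check_same_thread=True by default
relaxed = sqlite3.connect(":memory:", check_same_thread=False)
print(query_from_other_thread(strict))   # {'error': 'ProgrammingError'}
print(query_from_other_thread(relaxed))  # {'value': 1}
strict.close()
relaxed.close()
```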

Conclusion/Thoughts

That’s everything you need to know to get started with dependency injection and tests in FastAPI!

TestClient is very powerful! Without adding much coupling between your tests and your application, you can now test endpoints with different cookies and headers. That also lets you probe how your system handles unexpected or malicious requests and check that it is secure.
