
Let's refactor a test: Store and Update OAuth connections

This post is part of my Advent of Code 2022.

Last time, in the Unit Testing 101 series, we refactored a unit test for a method that fed a report of transactions in a payment system. This time, let’s refactor another test. This test is based on a real test I had to refactor in one of my client’s projects.

Before looking at our test, a bit of background. This test belongs to a two-way integration between a Property Management System and a third-party service. Let’s call it Acme Corporation. To connect one of our properties to Acme, we go through an OAuth flow.

A bit of background on OAuth flows

To start the OAuth flow, we open Acme’s Authorize endpoint in a web browser. Acme prompts us to enter a username and password. Then, it returns an authorization code. With that code, we call a Token endpoint to grab the access and refresh tokens. We send the access token in a header in future requests.
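If you haven’t worked with OAuth before, here’s a minimal C# sketch of the two requests in that flow. It only builds the data: the Authorize URL we open in the browser and the form fields we POST to the Token endpoint. The endpoint URL, client ID, and redirect URI are hypothetical placeholders. The parameter names are the standard ones from OAuth 2.0 (RFC 6749) and PKCE (RFC 7636), but Acme’s real API may differ.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class AcmeOAuth
{
    // Hypothetical Acme endpoint; the real URL isn't part of the post.
    public const string AuthorizeEndpoint = "https://acme.example/oauth/authorize";

    // Step 1: the URL we open in a browser. After the user logs in,
    // Acme redirects back to redirectUri with ?code=...&state=...
    public static string BuildAuthorizeUrl(
        string clientId, string redirectUri, string state, string codeChallenge)
    {
        var query = new Dictionary<string, string>
        {
            ["response_type"] = "code",
            ["client_id"] = clientId,
            ["redirect_uri"] = redirectUri,
            ["state"] = state,
            ["code_challenge"] = codeChallenge,
            ["code_challenge_method"] = "S256",
        };

        // Percent-encode every value so URLs and symbols survive the query string.
        return AuthorizeEndpoint + "?" + string.Join("&",
            query.Select(kv => $"{kv.Key}={Uri.EscapeDataString(kv.Value)}"));
    }

    // Step 2: the form fields we POST to the Token endpoint to exchange
    // the authorization code for an access token and a refresh token.
    public static IReadOnlyDictionary<string, string> BuildTokenRequest(
        string clientId, string redirectUri, string authorizationCode, string codeVerifier)
        => new Dictionary<string, string>
        {
            ["grant_type"] = "authorization_code",
            ["code"] = authorizationCode,
            ["redirect_uri"] = redirectUri,
            ["client_id"] = clientId,
            ["code_verifier"] = codeVerifier,
        };
}
```

In real code, we could wrap those token fields in a FormUrlEncodedContent and POST them with HttpClient. Per RFC 6749, the JSON response carries access_token and, optionally, refresh_token and expires_in.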

Apart from the access and refresh tokens, we need one more piece to make this integration work both ways. We create some random credentials and send them to Acme. With these credentials, Acme calls some public endpoints on our side.
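The post doesn’t show how OurCredentials.GenerateCredentials() builds those random credentials. Here’s one possible approach: a minimal sketch that picks each character with a cryptographically secure random number generator. GeneratedCredentials, CredentialsGenerator, and the acme-{clientId} username format are all hypothetical.

```csharp
using System;
using System.Security.Cryptography;

// Hypothetical stand-in for OurCredentials: just a username and a password.
public sealed record GeneratedCredentials(string Username, string Password);

public static class CredentialsGenerator
{
    private const string Alphabet =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";

    // Hypothetical format: a recognizable prefix plus a random suffix for the username.
    public static GeneratedCredentials Generate(int clientId, int passwordLength = 32)
        => new($"acme-{clientId}-{RandomString(8)}", RandomString(passwordLength));

    // RandomNumberGenerator.GetInt32 returns a uniformly distributed index,
    // so every character of the alphabet is equally likely.
    private static string RandomString(int length)
    {
        var chars = new char[length];
        for (var i = 0; i < length; i++)
        {
            chars[i] = Alphabet[RandomNumberGenerator.GetInt32(Alphabet.Length)];
        }
        return new string(chars);
    }
}
```

We generate the password on our side because Acme uses it to call us. That means it should come from a secure random source, not from System.Random.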

Here’s the test to refactor

With this background, let’s look at the test we’re going to refactor. This is an integration test that checks that we can create, update and retrieve Acme “connections” in our database.

public class ConnectionRepositoryTests
{
    private static readonly ClientId ClientId = new ClientId(123456);
    
    private static readonly AcmeCredentials AcmeCredentials
      = new AcmeCredentials("AnyAuthenticationToken", "AnyRefreshToken", SomeFutureExpirationDate);
    
    private static readonly AcmeCredentials OtherAcmeCredentials
      = new AcmeCredentials("OtherAuthenticationToken", "OtherRefreshToken", SomeFutureExpirationDate);

    private static readonly AcmeCompany AcmeCompany
      = new AcmeCompany(AcmeCompanyId, AcmeCompanyName);

    private readonly Mock<IAcmeService> _acmeConnectionServiceMock
      = new Mock<IAcmeService>();

    [Fact]
    public async Task GetConnectionAsync_ConnectionUpdated_ReturnsUpdatedConnection()
    {
        var repository = new AcmeConnectionRepository(AnySqlConnection);
        var acmeConnection = new AcmeConnection(ClientId);
        var acmeConnectionId = await repository.CreateAcmeConnectionAsync(acmeConnection);
        acmeConnection.GeneratePkce();
        acmeConnection = AcmeConnection.Load(
            acmeConnectionId,
            ClientId,
            pkce: acmeConnection.Pkce,
            acmeCredentials: AcmeCredentials,
            ourCredentials: OurCredentials.GenerateCredentials(ClientId));
        await repository.UpdateAcmeConnectionAsync(acmeConnection);

        var connectionFromDb = await repository.GetAcmeConnectionAsync(ClientId);
        acmeConnection = AcmeConnection.Load(
            acmeConnectionId,
            ClientId,
            AcmeCompany,
            connectionFromDb!.Pkce);
        acmeConnection.GeneratePkce();
        acmeConnection = AcmeConnection.Load(
            acmeConnectionId,
            ClientId,
            AcmeCompany,
            acmeConnection.Pkce,
            connectionFromDb.AcmeCredentials,
            connectionFromDb.OurCredentials);
        acmeConnection.UpdateAcmeCredentials(OtherAcmeCredentials);
        await acmeConnection.SetOurCredentialsAsync(_acmeConnectionServiceMock.Object);
        await repository.UpdateAcmeConnectionAsync(acmeConnection);
        var updatedConnectionFromDb = await repository.GetAcmeConnectionAsync(new ClientId(ClientId));
        acmeConnection = AcmeConnection.Load(
            acmeConnectionId,
            ClientId,
            AcmeCompany,
            Pkce.Load(acmeConnection.Pkce!.Id!,
                      acmeConnection.Pkce.CodeVerifier,
                      updatedConnectionFromDb.Pkce!.CreatedDate,
                      updatedConnectionFromDb.Pkce.UpdatedDate),
            AcmeCredentials.Load(acmeConnection.AcmeCredentials!.Id!,
                                acmeConnection.AcmeCredentials.RefreshToken,
                                acmeConnection.AcmeCredentials.AccessToken,
                                acmeConnection.AcmeCredentials.AccessTokenExpiration,
                                updatedConnectionFromDb.AcmeCredentials!.CreatedDate,
                                updatedConnectionFromDb.AcmeCredentials.UpdatedDate),
            OurCredentials.Load(acmeConnection.OurCredentials!.Id!,
                                acmeConnection.OurCredentials.Username,
                                acmeConnection.OurCredentials.Password,
                                updatedConnectionFromDb.OurCredentials!.CreatedDate,
                                updatedConnectionFromDb.OurCredentials.UpdatedDate));

        Assert.NotNull(connectionFromDb);
        Assert.NotNull(updatedConnectionFromDb);
        Assert.Equal(acmeConnectionId, connectionFromDb!.Id);
        Assert.Equal(acmeConnectionId, updatedConnectionFromDb!.Id);
        Assert.Equal(acmeConnection, updatedConnectionFromDb);
        Assert.NotEqual(acmeConnection, connectionFromDb);
    }
}

Yes, that’s the real test. “Some names have been changed to protect the innocent.” Can you take a look and identify what our test does?

To be fair, here’s the AcmeConnection class with the signature of Load() and other methods:

public record AcmeConnection(ClientId ClientId)
{
    public static AcmeConnection Load(
        AcmeConnectionId dbId,
        ClientId clientId,
        AcmeCompany? acmeCompany = null,
        Pkce? pkce = null,
        AcmeCredentials? acmeCredentials = null,
        OurCredentials? ourCredentials = null)
    {
        // Create a new AcmeConnection from all the parameters
        // Beep, beep, boop...
    }

    // A bunch of methods to update the AcmeConnection state
    public void GeneratePkce() { /* ... */ }

    public void UpdateAcmeCompany(AcmeCompany company) { /* ... */ }

    public void UpdateAcmeCredentials(AcmeCredentials credentials) { /* ... */ }

    public async Task SetOurCredentialsAsync(IAcmeService service) { /* ... */ }
}

The Pkce object holds the two security codes we exchange in the OAuth flow: the code verifier and the code challenge. For more details, see Dropbox’s guide on PKCE.
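Here’s how those two codes relate in PKCE (RFC 7636). The client keeps a random code verifier secret and sends a hash of it, the code challenge, to the Authorize endpoint. Later, it sends the verifier itself to the Token endpoint, and the server checks that it hashes to the challenge it saw earlier. The sketch below shows the S256 method. PkceSketch and its method names are hypothetical and don’t claim to match the post’s Pkce class.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

public static class PkceSketch
{
    // RFC 7636: a verifier is 43 to 128 characters from [A-Z a-z 0-9 - . _ ~].
    // Base64url-encoding 32 random bytes yields exactly 43 characters.
    public static string CreateCodeVerifier()
        => Base64Url(RandomNumberGenerator.GetBytes(32));

    // S256 method: code_challenge = BASE64URL(SHA256(ASCII(code_verifier)))
    public static string CreateCodeChallenge(string codeVerifier)
        => Base64Url(SHA256.HashData(Encoding.ASCII.GetBytes(codeVerifier)));

    // Base64url without padding, as RFC 7636 requires.
    private static string Base64Url(byte[] bytes)
        => Convert.ToBase64String(bytes)
            .TrimEnd('=')
            .Replace('+', '-')
            .Replace('/', '_');
}
```

The test vector in RFC 7636 Appendix B makes a handy sanity check. The verifier dBjftJeZ4CVP-mB92K27uhbUJU1p1r-wW1gFWFOEjXk should produce the challenge E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM.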

Photo by John Barkiple on Unsplash

What’s wrong?

Did you spot what our test does? Don’t worry. It took me some time to get what this test does, even though I was familiar with that codebase.

That test is full of noise and hard to follow. It abuses the acmeConnection variable, reading connections and reassigning them to it over and over.

Behind all that noise, our test creates a new connection and stores it. Then, it retrieves, mutates, and updates the same connection. In the last step, it recreates another connection from all the input values to use in the Assert part.

Let’s see the test again, annotated this time:

[Fact]
public async Task GetConnectionAsync_ConnectionUpdated_ReturnsUpdatedConnection()
{
    var repository = new AcmeConnectionRepository(AnySqlConnection);
    var acmeConnection = new AcmeConnection(ClientId);
    var acmeConnectionId = await repository.CreateAcmeConnectionAsync(acmeConnection);
    // 1. Create connection                 ^^^^^
    
    acmeConnection.GeneratePkce();
    acmeConnection = AcmeConnection.Load(
        acmeConnectionId,
        ClientId,
        pkce: acmeConnection.Pkce,
        acmeCredentials: AcmeCredentials,
        ourCredentials: OurCredentials.GenerateCredentials(ClientId));
    //  ^^^^^
    // 2. Change both credentials
    await repository.UpdateAcmeConnectionAsync(acmeConnection);

    var connectionFromDb = await repository.GetAcmeConnectionAsync(ClientId);
    //                                      ^^^^^
    // 3. Retrieve the newly created connection
    acmeConnection = AcmeConnection.Load(
        acmeConnectionId,
        ClientId,
        AcmeCompany,
        connectionFromDb!.Pkce);
    //  ^^^^^
    acmeConnection.GeneratePkce();
    //             ^^^^
    acmeConnection = AcmeConnection.Load(
        acmeConnectionId,
        ClientId,
        AcmeCompany,
        acmeConnection.Pkce,
        connectionFromDb.AcmeCredentials,
        connectionFromDb.OurCredentials);
    acmeConnection.UpdateAcmeCredentials(OtherAcmeCredentials);
    //             ^^^^^
    await acmeConnection.SetOurCredentialsAsync(_acmeConnectionServiceMock.Object);
    //                   ^^^^^
    // 4. Change Acme company and both credentials again
    await repository.UpdateAcmeConnectionAsync(acmeConnection);
    //               ^^^^^
    // 5. Update
    
    var updatedConnectionFromDb = await repository.GetAcmeConnectionAsync(new ClientId(ClientId));
    acmeConnection = AcmeConnection.Load(
    //                              ^^^^^
        acmeConnectionId,
        ClientId,
        AcmeCompany,
        Pkce.Load(acmeConnection.Pkce!.Id!,
                  acmeConnection.Pkce.CodeVerifier,
                  updatedConnectionFromDb.Pkce!.CreatedDate,
                  updatedConnectionFromDb.Pkce.UpdatedDate),
        AcmeCredentials.Load(acmeConnection.AcmeCredentials!.Id!,
                            acmeConnection.AcmeCredentials.RefreshToken,
                            acmeConnection.AcmeCredentials.AccessToken,
                            acmeConnection.AcmeCredentials.AccessTokenExpiration,
                            updatedConnectionFromDb.AcmeCredentials!.CreatedDate,
                            updatedConnectionFromDb.AcmeCredentials.UpdatedDate),
        OurCredentials.Load(acmeConnection.OurCredentials!.Id!,
                            acmeConnection.OurCredentials.Username,
                            acmeConnection.OurCredentials.Password,
                            updatedConnectionFromDb.OurCredentials!.CreatedDate,
                            updatedConnectionFromDb.OurCredentials.UpdatedDate));

    Assert.NotNull(connectionFromDb);
    Assert.NotNull(updatedConnectionFromDb);
    Assert.Equal(acmeConnectionId, connectionFromDb!.Id);
    Assert.Equal(acmeConnectionId, updatedConnectionFromDb!.Id);
    Assert.Equal(acmeConnection, updatedConnectionFromDb);
    Assert.NotEqual(acmeConnection, connectionFromDb);
}

Also, this test keeps using the Load() method, even though the AcmeConnection class has some methods to update its own state.

Step 1. Use the same code as the production code

Write integration tests using the same code as the production code.

Let’s write our test in terms of our business methods instead of using Load() everywhere.

[Fact]
public async Task GetConnectionAsync_ConnectionUpdated_ReturnsUpdatedConnection()
{
    var repository = new AcmeConnectionRepository(AnySqlConnection);
    var acmeConnection = new AcmeConnection(ClientId);
    var acmeConnectionId = await repository.CreateAcmeConnectionAsync(acmeConnection);
    // 1. Create connection                 ^^^^^
    
    acmeConnection = await repository.GetAcmeConnectionAsync(ClientId);
    acmeConnection.GeneratePkce();
    //             ^^^^^
    await repository.UpdateAcmeConnectionAsync(acmeConnection);
    //               ^^^^^
    // 2. Update pkce

    acmeConnection = await repository.GetAcmeConnectionAsync(ClientId);
    acmeConnection.UpdateAcmeCompany(AcmeCompany);
    //             ^^^^^
    acmeConnection.UpdateAcmeCredentials(OtherAcmeCredentials);
    //             ^^^^^
    await acmeConnection.SetOurCredentialsAsync(_acmeConnectionServiceMock.Object);
    //                   ^^^^^
    await repository.UpdateAcmeConnectionAsync(acmeConnection);
    //               ^^^^^
    // 3. Update company and credentials
    
    var updatedConnectionFromDb = await repository.GetAcmeConnectionAsync(ClientId);

    Assert.NotNull(updatedConnectionFromDb);
    Assert.Equal(acmeConnectionId, updatedConnectionFromDb!.Id);
    Assert.Equal(acmeConnection.Pkce, updatedConnectionFromDb.Pkce);
    Assert.Equal(acmeConnection.AcmeCompany, updatedConnectionFromDb.AcmeCompany);
    Assert.NotNull(updatedConnectionFromDb.AcmeCredentials);
    Assert.NotNull(updatedConnectionFromDb.OurCredentials);
}

Notice that we stopped using the Load() method. We rewrote the test using the methods from the AcmeConnection class, like UpdateAcmeCredentials(), SetOurCredentialsAsync(), and others.

Also, we separated the test into blocks. In each block, we retrieved the acmeConnection, mutated it with its own methods, and called UpdateAcmeConnectionAsync(). Cleaner, I’d say!

We removed the last Load() call. We didn’t need to assert whether the last retrieved object was exactly the same as a recreated version. Instead, we checked two things. First, the updated connection had the same value objects (Pkce and AcmeCompany) as the one we saved. Second, both credentials were stored.
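Those assertions rely on value objects comparing by value. The post doesn’t say whether Pkce, AcmeCompany, and the credential types are C# records like AcmeConnection. If they are, Assert.Equal ends up comparing their properties, not their references. The sketch below uses hypothetical stand-in records to show both sides of that. The second case shows that every property of a record takes part in equality, including database-generated ones like a CreatedDate.

```csharp
using System;

// Hypothetical value objects shaped like the post's types.
// Records get value-based Equals, == and GetHashCode for free.
public sealed record AcmeCredentials(
    string AccessToken, string RefreshToken, DateTimeOffset AccessTokenExpiration);

// A record that also carries a database-generated timestamp.
public sealed record AuditedPkce(string CodeVerifier, DateTimeOffset CreatedDate);

public static class ValueEqualityDemo
{
    private static readonly DateTimeOffset Expiration =
        new(2030, 1, 1, 0, 0, 0, TimeSpan.Zero);

    // Two separate instances with the same values: equal, but not the same reference.
    public static (bool Equal, bool SameReference) CompareCredentials()
    {
        var saved = new AcmeCredentials("OtherAuthenticationToken", "OtherRefreshToken", Expiration);
        var fromDb = new AcmeCredentials("OtherAuthenticationToken", "OtherRefreshToken", Expiration);
        return (saved == fromDb, ReferenceEquals(saved, fromDb));
    }

    // Same verifier, CreatedDate one second apart: the records are NOT equal.
    public static bool ComparePkceWithDifferentDates()
        => new AuditedPkce("verifier", Expiration)
            == new AuditedPkce("verifier", Expiration.AddSeconds(1));
}
```

That second case may explain why the original test copied CreatedDate and UpdatedDate from the database into its expected object before asserting.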

Step 2. Use descriptive variables

For the next step, let’s stop abusing the same acmeConnection variable and create more descriptive variables for every step.

[Fact]
public async Task GetConnectionAsync_ConnectionUpdated_ReturnsUpdatedConnection()
{
    var repository = new AcmeConnectionRepository(AnySqlConnection);
    var acmeConnection = new AcmeConnection(ClientId);
    var acmeConnectionId = await repository.CreateAcmeConnectionAsync(acmeConnection);
    
    var newlyCreated = await repository.GetAcmeConnectionAsync(ClientId);
    //  ^^^^^
    newlyCreated.GeneratePkce();
    await repository.UpdateAcmeConnectionAsync(newlyCreated);

    var pkceUpdated = await repository.GetAcmeConnectionAsync(ClientId);
    //  ^^^^^
    pkceUpdated.UpdateAcmeCompany(AcmeCompany);
    pkceUpdated.UpdateAcmeCredentials(OtherAcmeCredentials);
    await pkceUpdated.SetOurCredentialsAsync(_acmeConnectionServiceMock.Object);
    await repository.UpdateAcmeConnectionAsync(pkceUpdated);
    
    var updated = await repository.GetAcmeConnectionAsync(ClientId);
    //  ^^^^^

    Assert.NotNull(updated);
    Assert.Equal(acmeConnectionId, updated!.Id);
    Assert.Equal(pkceUpdated.Pkce, updated.Pkce);
    Assert.Equal(pkceUpdated.AcmeCompany, updated.AcmeCompany);
    Assert.NotNull(updated.AcmeCredentials);
    Assert.NotNull(updated.OurCredentials);
}

With these variable names, it’s easier to follow what our test does.
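To see the retrieve, mutate, and update rhythm without a database, here’s a minimal, self-contained sketch. It uses simplified, hypothetical versions of AcmeConnection, Pkce, and AcmeCompany. It also uses an in-memory stand-in for AcmeConnectionRepository that returns copies, mimicking a fresh read from the database. The real test runs against SQL, so this only illustrates the flow. It isn’t a replacement for the integration test.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

// Simplified, hypothetical versions of the post's types.
public sealed record Pkce(string CodeVerifier);
public sealed record AcmeCompany(int Id, string Name);

public sealed class AcmeConnection
{
    public AcmeConnection(int clientId) => ClientId = clientId;

    public int Id { get; internal set; }
    public int ClientId { get; }
    public Pkce? Pkce { get; private set; }
    public AcmeCompany? AcmeCompany { get; private set; }

    // Business methods mutate state, so tests don't need a Load() backdoor.
    public void GeneratePkce() => Pkce = new Pkce(Guid.NewGuid().ToString("N"));
    public void UpdateAcmeCompany(AcmeCompany company) => AcmeCompany = company;

    // Shallow copy, used by the fake repository to mimic reading a fresh row.
    internal AcmeConnection Copy() => (AcmeConnection)MemberwiseClone();
}

// Hypothetical in-memory stand-in for AcmeConnectionRepository.
public sealed class InMemoryAcmeConnectionRepository
{
    private readonly Dictionary<int, AcmeConnection> _rowsByClientId = new();
    private int _nextId = 1;

    public Task<int> CreateAcmeConnectionAsync(AcmeConnection connection)
    {
        connection.Id = _nextId++;
        _rowsByClientId[connection.ClientId] = connection.Copy();
        return Task.FromResult(connection.Id);
    }

    public Task<AcmeConnection?> GetAcmeConnectionAsync(int clientId)
        => Task.FromResult<AcmeConnection?>(
            _rowsByClientId.TryGetValue(clientId, out var row) ? row.Copy() : null);

    public Task UpdateAcmeConnectionAsync(AcmeConnection connection)
    {
        _rowsByClientId[connection.ClientId] = connection.Copy();
        return Task.CompletedTask;
    }
}

public static class RoundTripDemo
{
    // Mirrors the refactored test, using the same descriptive variable names.
    public static async Task<(int Id, AcmeConnection PkceUpdated, AcmeConnection Updated)> RunAsync(int clientId)
    {
        var repository = new InMemoryAcmeConnectionRepository();
        var id = await repository.CreateAcmeConnectionAsync(new AcmeConnection(clientId));

        var newlyCreated = (await repository.GetAcmeConnectionAsync(clientId))!;
        newlyCreated.GeneratePkce();
        await repository.UpdateAcmeConnectionAsync(newlyCreated);

        var pkceUpdated = (await repository.GetAcmeConnectionAsync(clientId))!;
        pkceUpdated.UpdateAcmeCompany(new AcmeCompany(1, "Any company"));
        await repository.UpdateAcmeConnectionAsync(pkceUpdated);

        var updated = (await repository.GetAcmeConnectionAsync(clientId))!;
        return (id, pkceUpdated, updated);
    }
}
```

Returning copies from the fake matters. If it handed back the same instance, the assertions would pass even if UpdateAcmeConnectionAsync did nothing.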

An alternative solution with factory methods

We were lucky there were a lot of methods on the AcmeConnection class to mutate and update it in the tests. If we didn’t have those methods, we could create one “clone” method for every property we needed to mutate.

For example:

public static class AcmeConnectionExtensions
{
    public static AcmeConnection CredentialsFrom(
        this AcmeConnection self,
        AcmeCredentials acmeCredentials,
        OurCredentials ourCredentials)
    {
        // Copy self and change AcmeCredentials and OurCredentials
    }

    public static AcmeConnection AcmeCompanyFrom(
        this AcmeConnection self,
        AcmeCompany acmeCompany)
    {
        // Copy self and change the AcmeCompany
    }
}

We can create an initial AcmeConnection and clone it with our helper methods to remove the boilerplate from our original test.
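AcmeConnection is declared as a record, so one way to write those clone helpers is a with expression. It copies every property and overrides only the ones we list. Here’s a minimal sketch. It assumes the connection’s properties are init-only positional parameters, which the post doesn’t confirm, and the simplified record shapes are hypothetical.

```csharp
using System;

// Hypothetical, simplified shapes; the real types carry more data.
public sealed record AcmeCompany(int Id, string Name);
public sealed record AcmeCredentials(string AccessToken, string RefreshToken);
public sealed record OurCredentials(string Username, string Password);

// Assumes init-only positional properties, so "with" can override them.
public sealed record AcmeConnection(
    int ClientId,
    AcmeCompany? AcmeCompany = null,
    AcmeCredentials? AcmeCredentials = null,
    OurCredentials? OurCredentials = null);

public static class AcmeConnectionExtensions
{
    // "with" returns a new record: self is copied, then the listed properties are replaced.
    public static AcmeConnection CredentialsFrom(
        this AcmeConnection self,
        AcmeCredentials acmeCredentials,
        OurCredentials ourCredentials)
        => self with { AcmeCredentials = acmeCredentials, OurCredentials = ourCredentials };

    public static AcmeConnection AcmeCompanyFrom(
        this AcmeConnection self,
        AcmeCompany acmeCompany)
        => self with { AcmeCompany = acmeCompany };
}
```

A with expression never mutates self. That lets a test keep the initial connection as a baseline and derive each expected state from it, one helper call at a time.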

Voilà! That was a long refactoring session. There are two things we can take away from it. First, we should strive for readability in our tests. We should make our tests even more readable than our production code. Can anyone spot what one of our tests does in 30 seconds? That’s a readable test. Second, we should always write our tests using the same code as our production code. We shouldn’t write production code only to use it inside our tests. That Load() method was a backdoor to build objects when we should have used the class constructor and its methods to mutate its state.

To read more content about unit testing, check out how to write tests for HttpClient, how to test an ASP.NET filter, and how to write tests for logging messages. Don’t miss my Unit Testing 101 series, where I cover everything from naming conventions to best practices.

Happy testing!