Mocking an Entity Framework Core DbSet

Entity Framework Core adds IQueryable<T> extension methods for asynchronous database operations, such as

  • FirstAsync
  • SingleOrDefaultAsync
  • ToArrayAsync

The problem: when you mock an IQueryable<T> in a unit test and the code under test calls one of these extensions, you get the following error:

The provider for the source IQueryable doesn't implement IAsyncQueryProvider. Only providers that implement IAsyncQueryProvider can be used for Entity Framework asynchronous operations
Solutions can be found in various Stack Overflow posts, but they usually target EF 6 or an older version of EF Core. Here is a solution that works with Entity Framework Core 3.x.

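When mocking services for the code under test, you can return any IEnumerable<T>, as long as you convert it with the AsAsyncQueryable() extension method defined below (for example, return `customers.AsAsyncQueryable()` from the mocked member).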

 

 


using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.Internal;

namespace MyApp.Extensions
{
    /// <summary>
    /// Extension methods for turning in-memory data into queries that can be returned from mocks.
    /// </summary>
    public static class QueryableExtensions
    {
        /// <summary>
        /// Wraps <paramref name="input"/> in an <see cref="IQueryable{T}"/> whose provider implements
        /// <c>IAsyncQueryProvider</c> and which implements <see cref="IAsyncEnumerable{T}"/>, so it can
        /// stand in for EF Core query data in unit tests.
        /// </summary>
        /// <typeparam name="T">Element type of the sequence.</typeparam>
        /// <param name="input">The in-memory data to expose as a query.</param>
        /// <returns>An async-capable query over <paramref name="input"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
        public static IQueryable<T> AsAsyncQueryable<T>(this IEnumerable<T> input)
        {
            // Reject null up front so the failure points at the caller rather than at a later
            // query execution inside the code under test.
            if (input is null)
            {
                throw new ArgumentNullException(nameof(input));
            }

            return new TestAsyncEnumerable<T>(input);
        }
    }
    /// <summary>
    /// An <c>IAsyncQueryProvider</c> that evaluates queries in memory by delegating to an ordinary
    /// (synchronous) <see cref="IQueryProvider"/>, wrapping results in completed tasks.
    /// </summary>
    /// <typeparam name="TEntity">
    /// Element type of the source the provider was created for; used as a fallback element type
    /// when one cannot be determined from a query expression.
    /// </typeparam>
    internal class TestAsyncQueryProvider<TEntity> : IAsyncQueryProvider
    {
        // Open generic helper, closed over the requested element type in ExecuteAsync.
        private static readonly MethodInfo ExecuteAsTaskMethod =
            typeof(TestAsyncQueryProvider<TEntity>).GetMethod(
                nameof(ExecuteAsTask), BindingFlags.NonPublic | BindingFlags.Static);

        private readonly IQueryProvider _inner;

        /// <summary>Creates a provider that executes queries with <paramref name="inner"/>.</summary>
        /// <param name="inner">The synchronous provider that actually evaluates expressions.</param>
        /// <exception cref="ArgumentNullException"><paramref name="inner"/> is null.</exception>
        public TestAsyncQueryProvider(IQueryProvider inner)
        {
            _inner = inner ?? throw new ArgumentNullException(nameof(inner));
        }

        /// <summary>
        /// Executes <paramref name="expression"/> synchronously and returns the outcome as a
        /// <see cref="Task{T}"/>, where <typeparamref name="TResult"/> must be <c>Task&lt;T&gt;</c>.
        /// </summary>
        /// <exception cref="NotSupportedException"><typeparamref name="TResult"/> is not <c>Task&lt;T&gt;</c>.</exception>
        TResult IAsyncQueryProvider.ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
        {
            Type resultType = typeof(TResult);
            if (!resultType.IsGenericType || resultType.GetGenericTypeDefinition() != typeof(Task<>))
            {
                throw new NotSupportedException(
                    $"TestAsyncQueryProvider only supports Task<T> results, but {resultType} was requested.");
            }

            Type elementType = resultType.GetGenericArguments()[0];

            // ExecuteAsTask converts every failure into a faulted task, so Invoke never has an
            // exception to wrap in TargetInvocationException; callers see the original exception
            // when they await the result.
            return (TResult)ExecuteAsTaskMethod
                .MakeGenericMethod(elementType)
                .Invoke(null, new object[] { _inner, expression, cancellationToken });
        }

        /// <summary>
        /// Creates an async-capable query for <paramref name="expression"/>, using the element type
        /// of the expression itself (which may differ from <typeparamref name="TEntity"/>, e.g. after a projection).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
        IQueryable IQueryProvider.CreateQuery(Expression expression)
        {
            if (expression is null)
            {
                throw new ArgumentNullException(nameof(expression));
            }

            Type elementType = GetElementType(expression.Type);
            return (IQueryable)Activator.CreateInstance(
                typeof(TestAsyncEnumerable<>).MakeGenericType(elementType), expression);
        }

        /// <summary>Creates an async-capable query for <paramref name="expression"/>.</summary>
        IQueryable<TElement> IQueryProvider.CreateQuery<TElement>(Expression expression)
        {
            return new TestAsyncEnumerable<TElement>(expression);
        }

        /// <summary>Executes <paramref name="expression"/> with the inner provider.</summary>
        object IQueryProvider.Execute(Expression expression)
        {
            return _inner.Execute(expression);
        }

        /// <summary>Executes <paramref name="expression"/> with the inner provider.</summary>
        TResult IQueryProvider.Execute<TResult>(Expression expression)
        {
            return _inner.Execute<TResult>(expression);
        }

        // Runs the query synchronously and reports the outcome through a task, the way an async
        // method would: canceled for an already-canceled token, faulted if the query throws.
        private static Task<T> ExecuteAsTask<T>(IQueryProvider inner, Expression expression, CancellationToken cancellationToken)
        {
            if (cancellationToken.IsCancellationRequested)
            {
                return Task.FromCanceled<T>(cancellationToken);
            }

            try
            {
                return Task.FromResult(inner.Execute<T>(expression));
            }
            catch (Exception ex)
            {
                return Task.FromException<T>(ex);
            }
        }

        // Returns T for a sequence type that is or implements IEnumerable<T>; falls back to TEntity
        // (the original behavior) when no such interface is found.
        private static Type GetElementType(Type sequenceType)
        {
            Type enumerableType =
                sequenceType.IsGenericType && sequenceType.GetGenericTypeDefinition() == typeof(IEnumerable<>)
                    ? sequenceType
                    : sequenceType.GetInterfaces().FirstOrDefault(
                        i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>));

            return enumerableType?.GetGenericArguments()[0] ?? typeof(TEntity);
        }
    }

    /// <summary>
    /// An in-memory LINQ query (<see cref="EnumerableQuery{T}"/>) that additionally implements
    /// <see cref="IAsyncEnumerable{T}"/> and exposes a <c>TestAsyncQueryProvider</c> as its provider.
    /// </summary>
    /// <typeparam name="T">Element type of the query.</typeparam>
    // IQueryable<T> is re-listed on purpose: EnumerableQuery<T> already implements IQueryable.Provider
    // explicitly, and re-declaring the interface is what allows this class to re-implement it below.
    internal class TestAsyncEnumerable<T> : EnumerableQuery<T>, IAsyncEnumerable<T>, IQueryable<T>
    {
        // Created on first access and reused, rather than allocating a new provider on every read.
        private IQueryProvider _provider;

        /// <summary>Creates a query over an in-memory sequence.</summary>
        /// <param name="enumerable">The source data.</param>
        public TestAsyncEnumerable(IEnumerable<T> enumerable)
            : base(enumerable)
        { }

        /// <summary>Creates a query from an expression tree built by a provider.</summary>
        /// <param name="expression">The query expression.</param>
        public TestAsyncEnumerable(Expression expression)
            : base(expression)
        { }

        /// <summary>
        /// Returns an async enumerator over the query results.
        /// </summary>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> is already canceled.</exception>
        IAsyncEnumerator<T> IAsyncEnumerable<T>.GetAsyncEnumerator(CancellationToken cancellationToken)
        {
            // Honor a token that was canceled before enumeration started instead of silently ignoring it.
            cancellationToken.ThrowIfCancellationRequested();
            return new TestAsyncEnumerator<T>(this.AsEnumerable().GetEnumerator());
        }

        /// <summary>The async-capable provider used for further query composition and execution.</summary>
        IQueryProvider IQueryable.Provider => _provider ??= new TestAsyncQueryProvider<T>(this);
    }

    /// <summary>
    /// Adapts a synchronous <see cref="IEnumerator{T}"/> to <see cref="IAsyncEnumerator{T}"/>.
    /// Every operation runs synchronously and returns an already-completed value task.
    /// </summary>
    /// <typeparam name="T">Element type.</typeparam>
    internal class TestAsyncEnumerator<T> : IAsyncEnumerator<T>
    {
        private readonly IEnumerator<T> _source;

        /// <summary>Wraps <paramref name="inner"/>; disposing this enumerator disposes it.</summary>
        public TestAsyncEnumerator(IEnumerator<T> inner) => _source = inner;

        /// <summary>The element at the current position of the wrapped enumerator.</summary>
        T IAsyncEnumerator<T>.Current => _source.Current;

        /// <summary>Advances the wrapped enumerator synchronously.</summary>
        ValueTask<bool> IAsyncEnumerator<T>.MoveNextAsync()
        {
            bool advanced = _source.MoveNext();
            return new ValueTask<bool>(advanced);
        }

        /// <summary>Disposes the wrapped enumerator synchronously.</summary>
        ValueTask IAsyncDisposable.DisposeAsync()
        {
            _source.Dispose();
            return new ValueTask();
        }
    }
}

Comments

Popular posts from this blog

Angular - How to create composite controls that work with formGroup/formGroupName and ReactiveForms

Convert absolute path to relative path

Blazor setTimeout