Mocking an Entity Framework Core DbSet
Entity Framework Core provides IQueryable<T> extension methods for asynchronous database operations, such as:
- FirstAsync
- SingleOrDefaultAsync
- ToArrayAsync
The problem: when you mock an IQueryable<T> in a unit test and the code under test calls one of these extensions, you get the following error:
The provider for the source IQueryable doesn't implement IAsyncQueryProvider. Only providers that implement IAsyncQueryProvider can be used for Entity Framework asynchronous operations.
Solutions can be found in various Stack Overflow posts, but they usually target EF 6 or an older version of EF Core. Here is a solution that works for Entity Framework Core 3.x.
When mocking services for your test subject, you can return any IEnumerable<T>, as long as you convert it with the AsAsyncQueryable() extension method defined below.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.Internal;
namespace MyApp.Extensions
{
/// <summary>
/// Test helpers that expose in-memory data as an <see cref="IQueryable{T}"/> usable with
/// Entity Framework Core's asynchronous query extension methods.
/// </summary>
public static class QueryableExtensions
{
    /// <summary>
    /// Wraps <paramref name="input"/> in a queryable that is also an <see cref="IAsyncEnumerable{T}"/>
    /// and whose provider implements <c>IAsyncQueryProvider</c>, so that EF Core extensions such as
    /// <c>FirstAsync</c> or <c>ToArrayAsync</c> can run against it in unit tests.
    /// </summary>
    /// <typeparam name="T">Element type.</typeparam>
    /// <param name="input">In-memory data to query. It is wrapped, not copied.</param>
    /// <returns>An async-capable queryable over <paramref name="input"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
    public static IQueryable<T> AsAsyncQueryable<T>(this IEnumerable<T> input)
    {
        // Fail at the call site; otherwise a null source only blows up later, deep inside query execution.
        if (input == null)
        {
            throw new ArgumentNullException(nameof(input));
        }

        return new TestAsyncEnumerable<T>(input);
    }
}
/// <summary>
/// Async-capable query provider for in-memory test data. EF Core's async extension methods
/// (e.g. <c>FirstAsync</c>, <c>CountAsync</c>) require the source's provider to implement
/// <see cref="IAsyncQueryProvider"/>. This type satisfies that requirement by evaluating queries
/// synchronously through an inner provider and wrapping the result in a completed <see cref="Task{TResult}"/>.
/// </summary>
/// <typeparam name="TEntity">Element type of the queryable that created this provider.</typeparam>
internal class TestAsyncQueryProvider<TEntity> : IAsyncQueryProvider
{
    private readonly IQueryProvider _inner;

    /// <summary>Creates a provider that delegates expression evaluation to <paramref name="inner"/>.</summary>
    /// <param name="inner">Synchronous provider used to execute expressions.</param>
    /// <exception cref="ArgumentNullException"><paramref name="inner"/> is null.</exception>
    public TestAsyncQueryProvider(IQueryProvider inner)
    {
        _inner = inner ?? throw new ArgumentNullException(nameof(inner));
    }

    /// <summary>
    /// Executes <paramref name="expression"/> synchronously and returns its result as an already-completed task.
    /// </summary>
    /// <typeparam name="TResult">Must be <see cref="Task{T}"/>; its type argument is the query's result type.</typeparam>
    /// <exception cref="NotSupportedException"><typeparamref name="TResult"/> is not a <see cref="Task{T}"/>.</exception>
    /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> is already cancelled.</exception>
    TResult IAsyncQueryProvider.ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
    {
        Type taskType = typeof(TResult);
        if (!taskType.IsGenericType || taskType.GetGenericTypeDefinition() != typeof(Task<>))
        {
            throw new NotSupportedException(
                $"TestAsyncQueryProvider only supports results of type Task<T>, but {taskType} was requested.");
        }

        cancellationToken.ThrowIfCancellationRequested();

        Type expectedResultType = taskType.GetGenericArguments()[0];

        // Call the inner provider directly rather than through MethodInfo.Invoke, so query exceptions
        // (e.g. InvalidOperationException from FirstAsync on an empty sequence) reach the caller with
        // their original type instead of being wrapped in a TargetInvocationException.
        object result = _inner.Execute(expression);

        // Equivalent to Task.FromResult<expectedResultType>(result); the type argument is only known at run time.
        var fromResult = typeof(Task)
            .GetMethod(nameof(Task.FromResult))
            .MakeGenericMethod(expectedResultType);
        return (TResult)fromResult.Invoke(null, new[] { result });
    }

    /// <summary>Creates a query whose element type is taken from <paramref name="expression"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="expression"/> does not produce an <see cref="IQueryable{T}"/>.</exception>
    IQueryable IQueryProvider.CreateQuery(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression));
        }

        // The element type may differ from TEntity (e.g. after a Select), so derive it from the
        // expression instead of assuming TEntity.
        Type elementType = GetQueryableElementType(expression.Type);
        if (elementType == null)
        {
            throw new ArgumentException(
                $"Expression of type {expression.Type} does not produce an IQueryable<T>.", nameof(expression));
        }

        return (IQueryable)Activator.CreateInstance(
            typeof(TestAsyncEnumerable<>).MakeGenericType(elementType), new object[] { expression });
    }

    /// <summary>Creates a strongly typed, async-capable query over <paramref name="expression"/>.</summary>
    IQueryable<TElement> IQueryProvider.CreateQuery<TElement>(Expression expression)
    {
        return new TestAsyncEnumerable<TElement>(expression);
    }

    /// <summary>Executes <paramref name="expression"/> synchronously via the inner provider.</summary>
    object IQueryProvider.Execute(Expression expression)
    {
        return _inner.Execute(expression);
    }

    /// <summary>Executes <paramref name="expression"/> synchronously via the inner provider.</summary>
    TResult IQueryProvider.Execute<TResult>(Expression expression)
    {
        return _inner.Execute<TResult>(expression);
    }

    /// <summary>
    /// Returns T when <paramref name="type"/> is, or implements, IQueryable&lt;T&gt;; otherwise null.
    /// </summary>
    private static Type GetQueryableElementType(Type type)
    {
        if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IQueryable<>))
        {
            return type.GetGenericArguments()[0];
        }

        return type.GetInterfaces()
            .Where(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IQueryable<>))
            .Select(i => i.GetGenericArguments()[0])
            .FirstOrDefault();
    }
}
/// <summary>
/// LINQ-to-Objects queryable that additionally implements <see cref="IAsyncEnumerable{T}"/> and
/// reports an async-capable query provider through <see cref="IQueryable.Provider"/>.
/// </summary>
/// <typeparam name="T">Element type.</typeparam>
internal class TestAsyncEnumerable<T> : EnumerableQuery<T>, IAsyncEnumerable<T>, IQueryable<T>
{
    /// <summary>Creates a query over an in-memory sequence.</summary>
    /// <param name="enumerable">Sequence to query.</param>
    public TestAsyncEnumerable(IEnumerable<T> enumerable) : base(enumerable)
    {
    }

    /// <summary>Creates a query represented by an expression tree.</summary>
    /// <param name="expression">Expression describing the query.</param>
    public TestAsyncEnumerable(Expression expression) : base(expression)
    {
    }

    // Re-implemented explicitly so that IQueryable.Provider hands out the async-capable wrapper;
    // the wrapper is constructed around this instance, which it uses to execute expressions.
    IQueryProvider IQueryable.Provider => new TestAsyncQueryProvider<T>(this);

    // Enumerates this query synchronously and adapts the enumerator to the async interface.
    // The cancellation token is not observed.
    IAsyncEnumerator<T> IAsyncEnumerable<T>.GetAsyncEnumerator(CancellationToken cancellationToken)
    {
        IEnumerable<T> syncSource = this;
        return new TestAsyncEnumerator<T>(syncSource.GetEnumerator());
    }
}
/// <summary>
/// Adapts a synchronous <see cref="IEnumerator{T}"/> to <see cref="IAsyncEnumerator{T}"/>.
/// Every operation runs synchronously and returns an already-completed value task.
/// </summary>
/// <typeparam name="T">Element type.</typeparam>
internal class TestAsyncEnumerator<T> : IAsyncEnumerator<T>
{
    private readonly IEnumerator<T> _inner;

    /// <summary>Wraps <paramref name="inner"/>.</summary>
    /// <param name="inner">Enumerator to adapt; it is disposed when this enumerator is disposed.</param>
    public TestAsyncEnumerator(IEnumerator<T> inner) => _inner = inner;

    T IAsyncEnumerator<T>.Current => _inner.Current;

    // Advances the wrapped enumerator and reports the outcome as a completed ValueTask<bool>.
    ValueTask<bool> IAsyncEnumerator<T>.MoveNextAsync() => new ValueTask<bool>(_inner.MoveNext());

    ValueTask IAsyncDisposable.DisposeAsync()
    {
        _inner.Dispose();
        return new ValueTask(); // already completed
    }
}
}
Comments