2012-12-21

Decorating Unity extension

First, I want to give credit where it is due: the code in this post is based largely on the code I found in this post by Jim Christopher.
The original code required me to register my decorators before registering the type I wanted to decorate. I didn't like this because I prefer to register all the low-level services from within their own assemblies and then have the app (website, etc.) decorate those types at a higher level. So I changed the code in the following ways:
  1. It now uses the UnityContainer.Configure<> method to decorate types.
  2. The decorators may be registered before or after the registered type, or both, it doesn't matter.
  3. It is possible to register both generic and non-generic types, and both will be used if applicable (e.g. ICommand<> and ICommand<string> would apply to Command<string> but only ICommand<> would apply to Command<int>.)
  4. It works with child containers.
  5. It uses the context.NewBuild method instead of requiring an IUnityContainer reference.
The code is used like this:
//Program.cs
using System;
using Microsoft.Practices.Unity;

namespace ConsoleApplication5
{
    /// <summary>
    /// Console demo for the <see cref="Decorating"/> extension. It registers an open-generic
    /// decorator for every <c>ICommand&lt;T&gt;</c> plus an extra closed decorator for
    /// <c>ICommand&lt;string&gt;</c>. It then resolves and runs one command of each kind.
    /// </summary>
    class Program
    {
        static void Main(string[] args)
        {
            var container = new UnityContainer();
            container.AddNewExtension<Decorating>();

            // Decorators are configured here, before the decorated type is registered below.
            var decorating = container.Configure<Decorating>();
            IUnityContainer configured = decorating.Decorate(
                typeToDecorate: typeof(ICommand<>),
                decorateWith: typeof(CommandLogger<>));
            configured.Configure<Decorating>().Decorate<ICommand<string>, CommandLogger<string>>();

            container.RegisterType(typeof(ICommand<>), typeof(Command<>));

            // Only the open-generic decorator applies to ICommand<int>.
            ICommand<int> numberCommand = container.Resolve<ICommand<int>>();
            numberCommand.Execute(42);

            // Both the open-generic and the closed decorator apply to ICommand<string>.
            ICommand<string> textCommand = container.Resolve<ICommand<string>>();
            textCommand.Execute("Hello world");

            // Keep the console window open until Enter is pressed.
            Console.ReadLine();
        }
    }
}

//An interface type to resolve at runtime
using System;

namespace ConsoleApplication5
{
    /// <summary>A command that acts on a value of type <typeparamref name="T"/>.</summary>
    public interface ICommand<T>
    {
        /// <summary>Executes the command with the supplied value.</summary>
        void Execute(T value);
    }

    /// <summary>The concrete command: writes the value to the console.</summary>
    public class Command<T> : ICommand<T>
    {
        public Command()
        {
        }

        /// <summary>Writes "Command: " followed by the value (a null value is written as empty).</summary>
        public void Execute(T value)
        {
            // Concatenation instead of value.ToString(), so that a null reference-type
            // value (e.g. a null string) prints as empty instead of throwing.
            Console.WriteLine("Command: " + value);
        }
    }

    /// <summary>Decorator that logs the value and then forwards the call to the wrapped command.</summary>
    public class CommandLogger<T> : ICommand<T>
    {
        readonly ICommand<T> Inner;

        /// <param name="inner">The command being decorated.</param>
        /// <exception cref="ArgumentNullException"><paramref name="inner"/> is null.</exception>
        public CommandLogger(ICommand<T> inner)
        {
            // Fail at construction rather than with a NullReferenceException on first Execute.
            if (inner == null)
                throw new ArgumentNullException("inner");
            this.Inner = inner;
        }

        /// <summary>Writes "CommandLogger: " followed by the value, then calls the inner command.</summary>
        public void Execute(T value)
        {
            Console.WriteLine("CommandLogger: " + value);
            Inner.Execute(value);
        }
    }
}



And here is the code that implements the decorating extension. The trick is to let the normal build-up happen first, then pass that initial instance as the constructor parameter for the first decorator, then pass that decorator as the parameter for the next decorator to build, and so on.


//DecoratingExtension.cs
using System;
using Microsoft.Practices.Unity;
using Microsoft.Practices.Unity.ObjectBuilder;

namespace ConsoleApplication5
{
    /// <summary>
    /// Unity container extension that wraps resolved interface instances in registered decorators.
    /// Register decorators with <c>container.Configure&lt;Decorating&gt;().Decorate(...)</c>.
    /// Decorators are applied when an instance is built, so they may be registered before or
    /// after the decorated type itself.
    /// </summary>
    public class Decorating : UnityContainerExtension
    {
        // Created eagerly rather than in Initialize, so it is never null. The child-container
        // constructor below fills it before the extension is attached.
        readonly DecoratorTypeRegister Register = new DecoratorTypeRegister();

        public Decorating()
            : base() { }

        /// <summary>
        /// Creates an extension for a child container, seeded with a copy of the parent's
        /// decorator registrations, and adds it to that container.
        /// </summary>
        /// <remarks>
        /// NOTE: this is a snapshot. Decorators registered on the parent after the child was
        /// created are not seen by the child.
        /// </remarks>
        internal Decorating(IUnityContainer container, DecoratorTypeRegister register)
        {
            this.Register.CopyFrom(register);
            container.AddExtension(this);
        }

        protected override void Initialize()
        {
            Context.ChildContainerCreated += Context_ChildContainerCreated;
            Context.Strategies.Add(
                new DecoratingBuildStrategy(Register),
                UnityBuildStage.PreCreation);
        }

        void Context_ChildContainerCreated(object sender, ChildContainerCreatedEventArgs e)
        {
            // The constructor attaches the new extension to the child container,
            // so no reference needs to be kept here.
            new Decorating(e.ChildContainer, Register);
        }

        /// <summary>Registers <paramref name="decorateWith"/> as a decorator for <paramref name="typeToDecorate"/>.</summary>
        /// <param name="typeToDecorate">
        /// The interface to decorate. It may be a closed type or an open generic definition
        /// such as <c>typeof(ICommand&lt;&gt;)</c>.
        /// </param>
        /// <param name="decorateWith">
        /// The decorator type. An open generic definition is closed over the requested
        /// interface's generic arguments when it is built.
        /// </param>
        /// <returns>The container this extension belongs to, for fluent chaining.</returns>
        /// <exception cref="ArgumentNullException">Either argument is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="typeToDecorate"/> is not an interface.</exception>
        public IUnityContainer Decorate(Type typeToDecorate, Type decorateWith)
        {
            Register.Register(typeToDecorate: typeToDecorate, decorateWith: decorateWith);
            return Container;
        }

        /// <summary>Generic form of <see cref="Decorate(Type, Type)"/>.</summary>
        /// <returns>The container this extension belongs to, for fluent chaining.</returns>
        public IUnityContainer Decorate<TTypeToDecorate, TDecorateWith>()
        {
            return Decorate(typeof(TTypeToDecorate), typeof(TDecorateWith));
        }
    }
}


//DecoratingBuildStrategy.cs
using System;
using System.Collections.Generic;
using Microsoft.Practices.ObjectBuilder2;
using Microsoft.Practices.Unity;

namespace ConsoleApplication5
{
    /// <summary>
    /// Build strategy that runs after an interface has been built normally. It wraps the
    /// result in each decorator registered for that interface, covering both closed and
    /// open-generic registrations.
    /// </summary>
    internal class DecoratingBuildStrategy : BuilderStrategy
    {
        readonly DecoratorTypeRegister Register;

        internal DecoratingBuildStrategy(DecoratorTypeRegister register)
        {
            this.Register = register;
        }

        public override void PostBuildUp(IBuilderContext context)
        {
            base.PostBuildUp(context);

            // Decoration is keyed on the type the caller asked for, not the mapped concrete type.
            Type typeRequested = context.OriginalBuildKey.Type;
            if (!typeRequested.IsInterface)
                return;
            if (!Register.HasDecorators(typeRequested))
                return;

            // Loop-invariant: used to close open-generic decorator definitions.
            Type[] genericArgumentTypes = typeRequested.GetGenericArguments();

            // A Stack enumerates in reverse of GetDecoratorTypes' order. The last entry returned
            // therefore wraps the built instance first, and the first entry ends up outermost.
            var typeStack = new Stack<Type>(Register.GetDecoratorTypes(typeRequested));

            foreach (Type decoratorType in typeStack)
            {
                Type actualTypeToBuild = decoratorType.IsGenericTypeDefinition
                    ? decoratorType.MakeGenericType(genericArgumentTypes)
                    : decoratorType;

                // Satisfy the decorator's dependency on the requested interface with the
                // instance built so far, instead of letting the container resolve a new one.
                context.AddResolverOverrides(new DependencyOverride(
                    typeToConstruct: typeRequested,
                    dependencyValue: context.Existing));
                context.Existing = context.NewBuildUp(new NamedTypeBuildKey(actualTypeToBuild));
            }
        }
    }
}

//DecoratorTypeRegister.cs
using System;
using System.Collections.Generic;

namespace ConsoleApplication5
{
    /// <summary>
    /// Maps interface types to the decorator types registered for them, kept in registration
    /// order. Keys may be closed types or open generic definitions.
    /// </summary>
    internal class DecoratorTypeRegister
    {
        readonly Dictionary<Type, List<Type>> DecoratedTypes;

        public DecoratorTypeRegister()
        {
            this.DecoratedTypes = new Dictionary<Type, List<Type>>();
        }

        /// <summary>Appends <paramref name="decorateWith"/> to the decorators of <paramref name="typeToDecorate"/>.</summary>
        /// <exception cref="ArgumentNullException">Either argument is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="typeToDecorate"/> is not an interface.</exception>
        public void Register(
            Type typeToDecorate,
            Type decorateWith)
        {
            ValidateTypeToDecorate(typeToDecorate);
            if (decorateWith == null)
                throw new ArgumentNullException("decorateWith");

            GetOrAddDecoratorList(typeToDecorate).Add(decorateWith);
        }

        /// <summary>
        /// Generic form of <see cref="Register(Type, Type)"/>. Only the type arguments are
        /// used; the argument values are ignored.
        /// </summary>
        public void Register<TTypeToDecorate, TDecorateWith>(
            TTypeToDecorate typeToDecorate,
            TDecorateWith decorateWith)
        {
            Register(typeToDecorate: typeof(TTypeToDecorate), decorateWith: typeof(TDecorateWith));
        }

        /// <summary>
        /// Returns the decorators registered for exactly <paramref name="typeToDecorate"/>.
        /// If it is a closed generic type, these are followed by the decorators registered for
        /// its open generic definition.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="typeToDecorate"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="typeToDecorate"/> is not an interface.</exception>
        public IEnumerable<Type> GetDecoratorTypes(Type typeToDecorate)
        {
            ValidateTypeToDecorate(typeToDecorate);

            var result = new List<Type>();
            List<Type> registeredDecoratorTypes;
            if (DecoratedTypes.TryGetValue(typeToDecorate, out registeredDecoratorTypes))
                result.AddRange(registeredDecoratorTypes);

            // For an open definition this is null. Otherwise the definition (the same key as
            // above) would be looked up again and its decorators returned twice.
            Type openDefinition = GetOpenGenericDefinition(typeToDecorate);
            if (openDefinition != null && DecoratedTypes.TryGetValue(openDefinition, out registeredDecoratorTypes))
                result.AddRange(registeredDecoratorTypes);
            return result;
        }

        /// <summary>Generic form of <see cref="GetDecoratorTypes(Type)"/>.</summary>
        public IEnumerable<Type> GetDecoratorTypes<TTypeToDecorate>()
        {
            return GetDecoratorTypes(typeof(TTypeToDecorate));
        }

        /// <summary>
        /// True if any decorator is registered for the given type to decorate, or for its open
        /// generic definition. (Despite its name, <paramref name="decoratorType"/> is the
        /// decorated type, not a decorator.)
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="decoratorType"/> is null.</exception>
        public bool HasDecorators(Type decoratorType)
        {
            if (decoratorType == null)
                throw new ArgumentNullException("decoratorType");

            Type openDefinition = GetOpenGenericDefinition(decoratorType);
            return DecoratedTypes.ContainsKey(decoratorType) ||
                (openDefinition != null && DecoratedTypes.ContainsKey(openDefinition));
        }

        /// <summary>
        /// Appends every registration from <paramref name="register"/> to this register. The
        /// lists are copied, so later changes to either register do not affect the other.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="register"/> is null.</exception>
        internal void CopyFrom(DecoratorTypeRegister register)
        {
            if (register == null)
                throw new ArgumentNullException("register");

            foreach (var item in register.DecoratedTypes)
                GetOrAddDecoratorList(item.Key).AddRange(item.Value);
        }

        // Shared argument checks for a type that is to be decorated.
        static void ValidateTypeToDecorate(Type typeToDecorate)
        {
            if (typeToDecorate == null)
                throw new ArgumentNullException("typeToDecorate");
            if (!typeToDecorate.IsInterface)
                throw new ArgumentException("TypeToDecorate must be an interface", "typeToDecorate");
        }

        // The open generic definition of a closed generic type; null for non-generic types and
        // for types that already are open definitions.
        static Type GetOpenGenericDefinition(Type type)
        {
            return type.IsGenericType && !type.IsGenericTypeDefinition
                ? type.GetGenericTypeDefinition()
                : null;
        }

        // Returns the decorator list for a key, creating and storing an empty one if absent.
        List<Type> GetOrAddDecoratorList(Type typeToDecorate)
        {
            List<Type> decoratorTypeList;
            if (!DecoratedTypes.TryGetValue(typeToDecorate, out decoratorTypeList))
            {
                decoratorTypeList = new List<Type>();
                DecoratedTypes.Add(typeToDecorate, decoratorTypeList);
            }
            return decoratorTypeList;
        }
    }
}