Introduction
In almost every project I've worked on over the years, some kind of Enumeration class has been used.
Enumerations (or enum types for short) are a thin language wrapper around an integral type. You might want to limit their use to when you are storing one value from a closed set of values. Classification based on sizes (small, medium, large) is a good example. Using enums for control flow or more robust abstractions can be a code smell. This type of usage leads to fragile code with many control flow statements checking values of the enum.
Instead, you can create Enumeration classes that enable all the rich features of an object-oriented language.
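To make the difference concrete, here is a minimal sketch. The `SizeKind` and `Size` types and their shipping costs are hypothetical and only here for illustration. With a plain enum, the behaviour ends up in a `switch` that has to be updated whenever a member is added. With an Enumeration class, each value carries its own data.

```csharp
using System;

// Plain enum: behaviour has to live in switch statements elsewhere in the code.
public enum SizeKind { Small, Medium, Large }

public static class SizeKindExtensions
{
    // Hypothetical shipping costs; every new enum member needs a new case here.
    public static decimal ShippingCost(this SizeKind size) => size switch
    {
        SizeKind.Small => 5m,
        SizeKind.Medium => 10m,
        SizeKind.Large => 20m,
        _ => throw new ArgumentOutOfRangeException(nameof(size))
    };
}

// Enumeration class: each value carries its own data, so no switch is needed.
public sealed class Size
{
    public static readonly Size Small = new(1, nameof(Small), 5m);
    public static readonly Size Medium = new(2, nameof(Medium), 10m);
    public static readonly Size Large = new(3, nameof(Large), 20m);

    // Private constructor: no values can be added outside this class.
    private Size(int id, string name, decimal shippingCost)
    {
        Id = id;
        Name = name;
        ShippingCost = shippingCost;
    }

    public int Id { get; }
    public string Name { get; }
    public decimal ShippingCost { get; }

    public override string ToString() => Name;
}
```

Adding a new size to the class means adding one line with its data, rather than hunting down every `switch` on the enum.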
Example implementation of an Enumeration:
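public class CardType : Enumeration
{
    // readonly: the instances cannot be replaced once the type is initialized.
    public static readonly CardType Amex = new(1, nameof(Amex));
    public static readonly CardType Visa = new(2, nameof(Visa));
    public static readonly CardType MasterCard = new(3, nameof(MasterCard));
    // A private constructor keeps the set of card types closed.
    private CardType(int id, string name)
        : base(id, name)
    {
    }
}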
This week, I noticed that my current project also uses Enumeration classes. They've based their implementation on the following class provided by Microsoft.
Implementations
Microsoft
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
public abstract class Enumeration : IComparable
{
    public string Name { get; private set; }
    public int Id { get; private set; }
    protected Enumeration(int id, string name) => (Id, Name) = (id, name);
    public override string ToString() => Name;
    // Reflects over the public static fields of T on every call.
    public static IEnumerable<T> GetAll<T>() where T : Enumeration =>
        typeof(T).GetFields(BindingFlags.Public |
                            BindingFlags.Static |
                            BindingFlags.DeclaredOnly)
                 .Select(f => f.GetValue(null))
                 .Cast<T>();
    public override bool Equals(object obj)
    {
        if (obj is not Enumeration otherValue)
        {
            return false;
        }
        var typeMatches = GetType().Equals(obj.GetType());
        var valueMatches = Id.Equals(otherValue.Id);
        return typeMatches && valueMatches;
    }
    public override int GetHashCode() => Id.GetHashCode();
    public static int AbsoluteDifference(Enumeration firstValue, Enumeration secondValue)
    {
        var absoluteDifference = Math.Abs(firstValue.Id - secondValue.Id);
        return absoluteDifference;
    }
    public static T FromValue<T>(int value) where T : Enumeration
    {
        var matchingItem = Parse<T, int>(value, "value", item => item.Id == value);
        return matchingItem;
    }
    public static T FromDisplayName<T>(string displayName) where T : Enumeration
    {
        var matchingItem = Parse<T, string>(displayName, "display name", item => item.Name == displayName);
        return matchingItem;
    }
    // Linear scan over GetAll, so every lookup also pays the reflection cost.
    private static T Parse<T, K>(K value, string description, Func<T, bool> predicate) where T : Enumeration
    {
        var matchingItem = GetAll<T>().FirstOrDefault(predicate);
        if (matchingItem == null)
            throw new InvalidOperationException($"'{value}' is not a valid {description} in {typeof(T)}");
        return matchingItem;
    }
    public int CompareTo(object other) => Id.CompareTo(((Enumeration)other).Id);
}
One thing that caught my eye was the GetAll method. It uses reflection to return all the public static fields declared on a class (all CardTypes in the example above). It's not quite as bad as it looks, though: the runtime caches reflection metadata, so looking up the fields is cheaper after the first call. You can read more about that here. Some work is still repeated on every call, however: GetFields returns a new array each time, and every value is read through FieldInfo.GetValue.
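A quick way to see that part of the work is repeated: `Type.GetFields` hands back a new array on every call, and each element still has to be read through `FieldInfo.GetValue`. The sketch below uses a hypothetical `Colour` class and a hypothetical `ReflectionDemo` helper whose `GetAll` has the same shape as Microsoft's. Note that `GetFields` does not guarantee any particular field order.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Hypothetical enumeration-like class with three public static fields.
public sealed class Colour
{
    public static readonly Colour Red = new("Red");
    public static readonly Colour Green = new("Green");
    public static readonly Colour Blue = new("Blue");

    private Colour(string name) => Name = name;
    public string Name { get; }
}

public static class ReflectionDemo
{
    private const BindingFlags Flags =
        BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly;

    // Same shape as Microsoft's GetAll: reflect over the fields and read each value.
    public static IEnumerable<T> GetAll<T>() =>
        typeof(T).GetFields(Flags).Select(f => f.GetValue(null)).Cast<T>();

    // True when two consecutive GetFields calls return different array instances.
    public static bool GetFieldsAllocatesNewArray<T>() =>
        !ReferenceEquals(typeof(T).GetFields(Flags), typeof(T).GetFields(Flags));
}
```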
Our implementation at work
This is the version I found in my current project at work. It looks more or less the same as the Microsoft implementation, except that it renames Id/Name to Value/Description, validates its constructor arguments and adds a bunch of utility methods.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
// MaxDescriptionLength, Result, Result<T>, Error and ErrorTypes are project-specific and not shown here.
public abstract class Enumeration : IComparable
{
    protected Enumeration() { Description = string.Empty; }
    protected Enumeration(int value, string description)
    {
        // -1 is reserved for "null"; every other value must be positive.
        if (value < -1 || value == 0)
        {
            throw new Exception(
                $"Invalid value: {value}. Please use -1 to represent a null value and positive values otherwise.");
        }
        Value = value;
        if (description.Length > MaxDescriptionLength)
        {
            throw new Exception($"Display name can be max {MaxDescriptionLength} characters");
        }
        Description = description;
    }
    public int Value { get; }
    public string Description { get; }
    public string GetName<T>() where T : Enumeration, new()
    {
        var type = typeof(T);
        var fields = type.GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly);
        foreach (var info in fields)
        {
            // The instance argument is ignored for static fields, so new T() is not really needed here.
            var instance = new T();
            var locatedValue = info.GetValue(instance) as T;
            if (locatedValue?.Value == Value)
            {
                return info.Name;
            }
        }
        throw new Exception($"The enumeration value {Value} could not be found");
    }
    public override string ToString()
    {
        return Description;
    }
    // Iterator: the reflection runs again every time the returned sequence is enumerated.
    public static IEnumerable<T> GetAll<T>() where T : Enumeration
    {
        var type = typeof(T);
        var fields = type.GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly);
        foreach (var info in fields)
        {
            if (info.GetValue(null) is T locatedValue)
            {
                yield return locatedValue;
            }
        }
    }
    public override bool Equals(object? obj)
    {
        if (!(obj is Enumeration otherValue))
        {
            return false;
        }
        var typeMatches = GetType() == obj.GetType();
        var valueMatches = Value.Equals(otherValue.Value);
        return typeMatches && valueMatches;
    }
    public override int GetHashCode()
    {
        return HashCode.Combine(Value);
    }
    public static bool Exists<T>(int value) where T : Enumeration, new()
    {
        var matchingItem = GetMatchingItem<T>(item => item.Value == value);
        return matchingItem != null;
    }
    public static bool Exists<T>(string description) where T : Enumeration, new()
    {
        var matchingItem = GetMatchingItem<T>(item => item.Description == description);
        return matchingItem != null;
    }
    public static T FromValue<T>(int value) where T : Enumeration
    {
        var matchingItem = Parse<T, int>(value, "value", item => item.Value == value);
        return matchingItem;
    }
    // Non-throwing variant that reports failure through the project's Result type.
    public static Result<T> Parse<T>(int value) where T : Enumeration
    {
        var matchingItem = GetMatchingItem<T>(item => item.Value == value);
        if (matchingItem == null)
        {
            return Result.Failure<T>(
                new Error(ErrorTypes.Validation, $"'{value}' is not a valid value in {typeof(T)}"));
        }
        return Result.Success(matchingItem);
    }
    public static T FromDescription<T>(string description) where T : Enumeration
    {
        var matchingItem = Parse<T, string>(description, "description", item => item.Description == description);
        return matchingItem;
    }
    private static T Parse<T, TK>(TK value, string description, Func<T, bool> predicate) where T : Enumeration
    {
        var matchingItem = GetMatchingItem(predicate);
        if (matchingItem == null)
        {
            var message = $"'{value}' is not a valid {description} in {typeof(T)}";
            throw new Exception(message);
        }
        return matchingItem;
    }
    // Linear scan over GetAll, like the Microsoft version.
    public static T? GetMatchingItem<T>(Func<T, bool> predicate) where T : Enumeration
    {
        return GetAll<T>().FirstOrDefault(predicate);
    }
    public int CompareTo(object? other)
    {
        if (!(other is Enumeration e))
        {
            throw new ArgumentException("obj is not the same type as this instance");
        }
        return Value.CompareTo(e.Value);
    }
}
This implementation has the same GetAll problem: it uses reflection on every call.
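Because that GetAll is written as a `yield return` iterator, the reflection does not even run when GetAll is called; it runs each time the returned sequence is enumerated. So holding on to the returned `IEnumerable<T>` does not help either. A minimal sketch, using a hypothetical `Priority` class and a hypothetical `ReflectionRuns` counter to make this visible:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Hypothetical enumeration-like class with two public static fields.
public sealed class Priority
{
    public static readonly Priority Low = new(1);
    public static readonly Priority High = new(2);

    private Priority(int value) => Value = value;
    public int Value { get; }
}

public static class IteratorDemo
{
    // Hypothetical counter, incremented every time the iterator body starts running.
    public static int ReflectionRuns;

    // Same structure as the GetAll above.
    public static IEnumerable<T> GetAll<T>() where T : class
    {
        ReflectionRuns++;
        var fields = typeof(T).GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly);
        foreach (var info in fields)
        {
            if (info.GetValue(null) is T locatedValue)
            {
                yield return locatedValue;
            }
        }
    }
}
```

Calling `GetAll<Priority>()` alone leaves `ReflectionRuns` at 0; enumerating the same returned sequence twice (for example with two `Count()` calls) bumps it to 2.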
Record implementation
By using a record, we get the equality checks for free, since records implement the IEquatable<T> interface automatically. Sadly, records don't implement the IComparable interface, so we need to do that ourselves.
My idea was to make the Enumeration class generic, then use T in a static constructor and use reflection to get all instances. I would then cache all instances in two dictionaries, one keyed by Value and the other keyed by DisplayName. This allows for really fast lookups: O(1) on average, instead of scanning every item.
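The trick relies on the fact that every closed generic type, such as `Enumeration<CardType>`, gets its own copy of the static fields, and its static constructor runs once per closed type. A tiny sketch with a hypothetical `PerTypeCache<T>` shows that two closed types keep separate state:

```csharp
using System;

// Hypothetical cache: each closed type (PerTypeCache<string>, PerTypeCache<int>, ...)
// has its own copy of these static fields.
public static class PerTypeCache<T>
{
    public static readonly string TypeName;
    public static int InitCount;

    static PerTypeCache()
    {
        // Runs once per closed type, the first time that type is used.
        InitCount++;
        TypeName = typeof(T).Name;
    }
}
```

This is what lets the record below keep one pair of dictionaries per enumeration type without any explicit registration.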
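If we don't wrap the initialization of our dictionaries in a Lazy, we get runtime errors. The static constructor of Enumeration<T> typically runs when the first instance of the derived type (for example CardType.Amex) is created, that is, while the static fields of the derived type are still being initialized. If the reflection ran at that point, it would read fields that haven't been assigned yet, so everything would be null. Not good :). With Lazy, the reflection is deferred until the first lookup, when all the fields have been set.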
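The property that makes this work is that `Lazy<T>` only stores the factory when it is constructed and runs it on the first access to `.Value`, caching the result afterwards. A minimal sketch, with a hypothetical `LazyDemo` class and `FactoryRuns` counter:

```csharp
using System;
using System.Collections.Generic;

public static class LazyDemo
{
    // Hypothetical counter to observe when the factory runs.
    public static int FactoryRuns;

    // Creating the Lazy does not build the dictionary; the factory is only stored.
    public static readonly Lazy<Dictionary<int, string>> Items =
        new Lazy<Dictionary<int, string>>(() =>
        {
            FactoryRuns++;
            return new Dictionary<int, string> { [1] = "Amex", [2] = "Visa" };
        });
}
```

Reading `LazyDemo.Items.Value` the first time runs the factory; later reads reuse the same dictionary.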
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
public abstract record Enumeration<T> : IComparable<T> where T : Enumeration<T>
{
    private static readonly Lazy<Dictionary<int, T>> AllItems;
    private static readonly Lazy<Dictionary<string, T>> AllItemsByName;
    static Enumeration()
    {
        // Lazy defers the reflection until the first lookup, after T's static fields have been assigned.
        AllItems = new Lazy<Dictionary<int, T>>(() =>
        {
            return typeof(T)
                .GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
                .Where(x => x.FieldType == typeof(T))
                .Select(x => x.GetValue(null))
                .Cast<T>()
                .ToDictionary(x => x.Value, x => x);
        });
        AllItemsByName = new Lazy<Dictionary<string, T>>(() =>
        {
            var items = new Dictionary<string, T>(AllItems.Value.Count);
            foreach (var item in AllItems.Value.Values)
            {
                if (!items.TryAdd(item.DisplayName, item))
                {
                    throw new Exception(
                        $"DisplayName needs to be unique. '{item.DisplayName}' already exists");
                }
            }
            return items;
        });
    }
    protected Enumeration(int value, string displayName)
    {
        Value = value;
        DisplayName = displayName;
    }
    public int Value { get; }
    public string DisplayName { get; }
    // sealed stops derived records from synthesizing their own ToString (C# 10 and later).
    public sealed override string ToString() => DisplayName;
    public static IEnumerable<T> GetAll()
    {
        return AllItems.Value.Values;
    }
    public static int AbsoluteDifference(Enumeration<T> firstValue, Enumeration<T> secondValue)
    {
        return Math.Abs(firstValue.Value - secondValue.Value);
    }
    public static T FromValue(int value)
    {
        if (AllItems.Value.TryGetValue(value, out var matchingItem))
        {
            return matchingItem;
        }
        throw new InvalidOperationException($"'{value}' is not a valid value in {typeof(T)}");
    }
    public static T FromDisplayName(string displayName)
    {
        if (AllItemsByName.Value.TryGetValue(displayName, out var matchingItem))
        {
            return matchingItem;
        }
        throw new InvalidOperationException($"'{displayName}' is not a valid display name in {typeof(T)}");
    }
    // Any instance sorts after null, following the usual IComparable convention.
    public int CompareTo(T? other) => other is null ? 1 : Value.CompareTo(other.Value);
}
If you don't want to use records, it is of course possible to use the above approach with a class instead.
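As a sketch of what that could look like (not the code from my project, just one way to do it under the same assumptions as the record version), here is a class-based variant. Equality and hashing now have to be written by hand, which is exactly the code the record generated for us. `CardType` mirrors the example from the introduction.

```csharp
#nullable enable
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

public abstract class Enumeration<T> : IEquatable<T>, IComparable<T> where T : Enumeration<T>
{
    // Lazy defers the reflection until the first lookup, after T's static fields are set.
    private static readonly Lazy<Dictionary<int, T>> AllItems = new Lazy<Dictionary<int, T>>(() =>
        typeof(T)
            .GetFields(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
            .Where(f => f.FieldType == typeof(T))
            .Select(f => (T)f.GetValue(null)!)
            .ToDictionary(x => x.Value));

    // ToDictionary throws if two items share a DisplayName.
    private static readonly Lazy<Dictionary<string, T>> AllItemsByName = new Lazy<Dictionary<string, T>>(() =>
        AllItems.Value.Values.ToDictionary(x => x.DisplayName));

    protected Enumeration(int value, string displayName)
    {
        Value = value;
        DisplayName = displayName;
    }

    public int Value { get; }
    public string DisplayName { get; }

    public static IEnumerable<T> GetAll() => AllItems.Value.Values;

    public static T FromValue(int value) =>
        AllItems.Value.TryGetValue(value, out var item)
            ? item
            : throw new InvalidOperationException($"'{value}' is not a valid value in {typeof(T)}");

    public static T FromDisplayName(string displayName) =>
        AllItemsByName.Value.TryGetValue(displayName, out var item)
            ? item
            : throw new InvalidOperationException($"'{displayName}' is not a valid display name in {typeof(T)}");

    // Written by hand here; a record would generate this: same concrete type and same Value.
    public bool Equals(T? other) => other is not null && other.GetType() == GetType() && other.Value == Value;
    public override bool Equals(object? obj) => obj is T other && Equals(other);
    public override int GetHashCode() => HashCode.Combine(GetType(), Value);

    // Any instance sorts after null, following the usual IComparable convention.
    public int CompareTo(T? other) => other is null ? 1 : Value.CompareTo(other.Value);

    public override string ToString() => DisplayName;
}

public sealed class CardType : Enumeration<CardType>
{
    public static readonly CardType Amex = new(1, nameof(Amex));
    public static readonly CardType Visa = new(2, nameof(Visa));
    public static readonly CardType MasterCard = new(3, nameof(MasterCard));

    private CardType(int value, string displayName) : base(value, displayName) { }
}
```

Usage is the same as with the record: `CardType.FromValue(2)` returns `CardType.Visa`, and `CardType.FromDisplayName("Amex")` returns `CardType.Amex`.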
Let's do some benchmarks!
Benchmarks
BenchmarkDotNet=v0.13.2, OS=Windows 11 (10.0.22621.521)
11th Gen Intel Core i7-11370H 3.30GHz, 1 CPU, 8 logical and 4 physical cores
.NET SDK=7.0.100-rc.1.22431.12
[Host] : .NET 7.0.0 (7.0.22.42610), X64 RyuJIT AVX2
.NET 7.0 : .NET 7.0.0 (7.0.22.42610), X64 RyuJIT AVX2
Job=.NET 7.0 Runtime=.NET 7.0
| Method | Categories | Mean | Error | StdDev | Median | Ratio | RatioSD | Gen0 | Allocated | Alloc Ratio |
|-------------------- |----------- |----------:|---------:|---------:|----------:|------:|--------:|-------:|----------:|------------:|
| FromName_Microsoft | FromName | 148.71 ns | 2.985 ns | 5.228 ns | 146.31 ns | 1.00 | 0.00 | 0.0381 | 240 B | 1.00 |
| FromName_Ours | FromName | 108.93 ns | 1.214 ns | 1.077 ns | 108.70 ns | 0.72 | 0.03 | 0.0381 | 240 B | 1.00 |
| FromName_Record | FromName | 17.84 ns | 0.162 ns | 0.143 ns | 17.82 ns | 0.12 | 0.00 | - | - | 0.00 |
| | | | | | | | | | | |
| FromValue_Microsoft | FromValue | 192.86 ns | 3.731 ns | 5.585 ns | 192.47 ns | 1.00 | 0.00 | 0.0381 | 240 B | 1.00 |
| FromValue_Ours | FromValue | 148.31 ns | 1.598 ns | 1.335 ns | 148.00 ns | 0.77 | 0.02 | 0.0381 | 240 B | 1.00 |
| FromValue_Record | FromValue | 10.89 ns | 0.087 ns | 0.081 ns | 10.88 ns | 0.06 | 0.00 | - | - | 0.00 |
| | | | | | | | | | | |
| GetAll_Microsoft | GetAll | 240.63 ns | 4.043 ns | 3.782 ns | 239.61 ns | 1.00 | 0.00 | 0.0381 | 240 B | 1.00 |
| GetAll_Ours | GetAll | 186.35 ns | 2.347 ns | 2.196 ns | 185.49 ns | 0.77 | 0.01 | 0.0381 | 240 B | 1.00 |
| GetAll_Record | GetAll | 23.43 ns | 0.324 ns | 0.287 ns | 23.32 ns | 0.10 | 0.00 | 0.0127 | 80 B | 0.33 |
Great success! The record implementation is roughly 8 to 18 times faster than the Microsoft implementation, and its FromName and FromValue lookups don't allocate at all. This is because we are using the lazy generic approach, which caches all items so that we only pay the reflection cost once. It has nothing to do with using records, really. The only thing the record helps us with is that we need to write less code (Equals etc.). :)
So once again, the "generic approach" that we're using in the record implementation will work just as well in a "regular class implementation".
Another thing that greatly improves the performance is that we are using a Dictionary for the FromName and FromValue lookups, instead of scanning every item with FirstOrDefault as the other two implementations do.