diff --git a/osu.Game.Tests/Mods/ModValidationTest.cs b/osu.Game.Tests/Mods/ModValidationTest.cs new file mode 100644 index 0000000000..c7a7e242e9 --- /dev/null +++ b/osu.Game.Tests/Mods/ModValidationTest.cs @@ -0,0 +1,64 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using Moq; +using NUnit.Framework; +using osu.Game.Rulesets.Mods; +using osu.Game.Utils; + +namespace osu.Game.Tests.Mods +{ + [TestFixture] + public class ModValidationTest + { + [Test] + public void TestModIsCompatibleByItself() + { + var mod = new Mock(); + Assert.That(ModValidation.CheckCompatible(new[] { mod.Object })); + } + + [Test] + public void TestIncompatibleThroughTopLevel() + { + var mod1 = new Mock(); + var mod2 = new Mock(); + + mod1.Setup(m => m.IncompatibleMods).Returns(new[] { mod2.Object.GetType() }); + + // Test both orderings. + Assert.That(ModValidation.CheckCompatible(new[] { mod1.Object, mod2.Object }), Is.False); + Assert.That(ModValidation.CheckCompatible(new[] { mod2.Object, mod1.Object }), Is.False); + } + + [Test] + public void TestIncompatibleThroughMultiMod() + { + var mod1 = new Mock(); + + // The nested mod. + var mod2 = new Mock(); + mod2.Setup(m => m.IncompatibleMods).Returns(new[] { mod1.Object.GetType() }); + + var multiMod = new MultiMod(new MultiMod(mod2.Object)); + + // Test both orderings. + Assert.That(ModValidation.CheckCompatible(new[] { multiMod, mod1.Object }), Is.False); + Assert.That(ModValidation.CheckCompatible(new[] { mod1.Object, multiMod }), Is.False); + } + + [Test] + public void TestAllowedThroughMostDerivedType() + { + var mod = new Mock(); + Assert.That(ModValidation.CheckAllowed(new[] { mod.Object }, new[] { mod.Object.GetType() })); + } + + [Test] + public void TestNotAllowedThroughBaseType() + { + var mod = new Mock(); + Assert.That(ModValidation.CheckAllowed(new[] { mod.Object }, new[] { typeof(Mod) }), Is.False); + } + } +} diff --git a/osu.Game.Tests/osu.Game.Tests.csproj b/osu.Game.Tests/osu.Game.Tests.csproj index c0c0578391..d29ed94b5f 100644 --- a/osu.Game.Tests/osu.Game.Tests.csproj +++ b/osu.Game.Tests/osu.Game.Tests.csproj @@ -7,6 +7,7 @@ + WinExe diff --git a/osu.Game/Utils/ModValidation.cs b/osu.Game/Utils/ModValidation.cs new file mode 100644 index 0000000000..0c4d58ab2e --- /dev/null +++ b/osu.Game/Utils/ModValidation.cs @@ -0,0 +1,116 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using System; +using System.Collections.Generic; +using System.Linq; +using osu.Framework.Extensions.TypeExtensions; +using osu.Game.Rulesets.Mods; + +#nullable enable + +namespace osu.Game.Utils +{ + /// + /// A set of utilities to validate combinations. + /// + public static class ModValidation + { + /// + /// Checks that all s are compatible with each-other, and that all appear within a set of allowed types. + /// + /// + /// The allowed types must contain exact types for the respective s to be allowed. + /// + /// The s to check. + /// The set of allowed types. + /// Whether all s are compatible with each-other and appear in the set of allowed types. + public static bool CheckCompatibleAndAllowed(IEnumerable combination, IEnumerable allowedTypes) + { + // Prevent multiple-enumeration. + var combinationList = combination as ICollection ?? combination.ToArray(); + return CheckCompatible(combinationList) && CheckAllowed(combinationList, allowedTypes); + } + + /// + /// Checks that all s in a combination are compatible with each-other. + /// + /// The combination to check. + /// Whether all s in the combination are compatible with each-other. + public static bool CheckCompatible(IEnumerable combination) + { + var incompatibleTypes = new HashSet(); + var incomingTypes = new HashSet(); + + foreach (var mod in combination.SelectMany(flattenMod)) + { + // Add the new mod incompatibilities, checking whether any match the existing mod types. + foreach (var t in mod.IncompatibleMods) + { + if (incomingTypes.Contains(t)) + return false; + + incompatibleTypes.Add(t); + } + + // Add the new mod types, checking whether any match the incompatible types. + foreach (var t in mod.GetType().EnumerateBaseTypes()) + { + if (incomingTypes.Contains(t)) + return false; + + incomingTypes.Add(t); + } + } + + return true; + } + + /// + /// Checks that all s in a combination appear within a set of allowed types. + /// + /// + /// The set of allowed types must contain exact types for the respective s to be allowed. + /// + /// The combination to check. + /// The set of allowed types. + /// Whether all s in the combination are allowed. + public static bool CheckAllowed(IEnumerable combination, IEnumerable allowedTypes) + { + var allowedSet = new HashSet(allowedTypes); + + return combination.SelectMany(flattenMod) + .All(m => allowedSet.Contains(m.GetType())); + } + + /// + /// Determines whether a is in a set of incompatible types. + /// + /// + /// A can be incompatible through its most-declared type or any of its base types. + /// + /// The to test. + /// The set of incompatible types. + /// Whether the given is incompatible. + private static bool isModIncompatible(Mod mod, ICollection incompatibleTypes) + => flattenMod(mod) + .SelectMany(m => m.GetType().EnumerateBaseTypes()) + .Any(incompatibleTypes.Contains); + + /// + /// Flattens a , returning a set of s in-place of any s. + /// + /// The to flatten. + /// A set of singular "flattened" s + private static IEnumerable flattenMod(Mod mod) + { + if (mod is MultiMod multi) + { + foreach (var m in multi.Mods.SelectMany(flattenMod)) + yield return m; + } + else + yield return mod; + } + } +}