diff --git a/osu.Game.Tests/Mods/ModUtilsTest.cs b/osu.Game.Tests/Mods/ModUtilsTest.cs index 88eee5449c..ff1af88bac 100644 --- a/osu.Game.Tests/Mods/ModUtilsTest.cs +++ b/osu.Game.Tests/Mods/ModUtilsTest.cs @@ -18,50 +18,88 @@ namespace osu.Game.Tests.Mods [Test] public void TestModIsCompatibleByItself() { - var mod = new Mock(); + var mod = new Mock(); Assert.That(ModUtils.CheckCompatibleSet(new[] { mod.Object })); } [Test] public void TestIncompatibleThroughTopLevel() { - var mod1 = new Mock(); - var mod2 = new Mock(); + var mod1 = new Mock(); + var mod2 = new Mock(); mod1.Setup(m => m.IncompatibleMods).Returns(new[] { mod2.Object.GetType() }); // Test both orderings. - Assert.That(ModUtils.CheckCompatibleSet(new[] { mod1.Object, mod2.Object }), Is.False); - Assert.That(ModUtils.CheckCompatibleSet(new[] { mod2.Object, mod1.Object }), Is.False); + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod1.Object, mod2.Object }), Is.False); + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod2.Object, mod1.Object }), Is.False); } [Test] - public void TestIncompatibleThroughMultiMod() + public void TestMultiModIncompatibleWithTopLevel() { - var mod1 = new Mock(); + var mod1 = new Mock(); // The nested mod. - var mod2 = new Mock(); + var mod2 = new Mock(); mod2.Setup(m => m.IncompatibleMods).Returns(new[] { mod1.Object.GetType() }); var multiMod = new MultiMod(new MultiMod(mod2.Object)); // Test both orderings. - Assert.That(ModUtils.CheckCompatibleSet(new[] { multiMod, mod1.Object }), Is.False); - Assert.That(ModUtils.CheckCompatibleSet(new[] { mod1.Object, multiMod }), Is.False); + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { multiMod, mod1.Object }), Is.False); + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod1.Object, multiMod }), Is.False); + } + + [Test] + public void TestTopLevelIncompatibleWithMultiMod() + { + // The nested mod. + var mod1 = new Mock(); + var multiMod = new MultiMod(new MultiMod(mod1.Object)); + + var mod2 = new Mock(); + mod2.Setup(m => m.IncompatibleMods).Returns(new[] { typeof(CustomMod1) }); + + // Test both orderings. + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { multiMod, mod2.Object }), Is.False); + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod2.Object, multiMod }), Is.False); + } + + [Test] + public void TestCompatibleMods() + { + var mod1 = new Mock(); + var mod2 = new Mock(); + + // Test both orderings. + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod1.Object, mod2.Object }), Is.True); + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod2.Object, mod1.Object }), Is.True); + } + + [Test] + public void TestIncompatibleThroughBaseType() + { + var mod1 = new Mock(); + var mod2 = new Mock(); + mod2.Setup(m => m.IncompatibleMods).Returns(new[] { typeof(Mod) }); + + // Test both orderings. + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod1.Object, mod2.Object }), Is.False); + Assert.That(ModUtils.CheckCompatibleSet(new Mod[] { mod2.Object, mod1.Object }), Is.False); } [Test] public void TestAllowedThroughMostDerivedType() { - var mod = new Mock(); + var mod = new Mock(); Assert.That(ModUtils.CheckAllowed(new[] { mod.Object }, new[] { mod.Object.GetType() })); } [Test] public void TestNotAllowedThroughBaseType() { - var mod = new Mock(); + var mod = new Mock(); Assert.That(ModUtils.CheckAllowed(new[] { mod.Object }, new[] { typeof(Mod) }), Is.False); } @@ -88,5 +126,13 @@ namespace osu.Game.Tests.Mods else Assert.That(invalid?.Select(t => t.GetType()), Is.EquivalentTo(expectedInvalid)); } + + public abstract class CustomMod1 : Mod + { + } + + public abstract class CustomMod2 : Mod + { + } } } diff --git a/osu.Game/Utils/ModUtils.cs b/osu.Game/Utils/ModUtils.cs index 2146abacb6..05a07f0459 100644 --- a/osu.Game/Utils/ModUtils.cs +++ b/osu.Game/Utils/ModUtils.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Linq; -using osu.Framework.Extensions.TypeExtensions; using osu.Game.Rulesets.Mods; #nullable enable @@ -29,7 +28,7 @@ namespace osu.Game.Utils { // Prevent multiple-enumeration. var combinationList = combination as ICollection ?? combination.ToArray(); - return CheckCompatibleSet(combinationList) && CheckAllowed(combinationList, allowedTypes); + return CheckCompatibleSet(combinationList, out _) && CheckAllowed(combinationList, allowedTypes); } /// @@ -38,32 +37,32 @@ namespace osu.Game.Utils /// The combination to check. /// Whether all s in the combination are compatible with each-other. public static bool CheckCompatibleSet(IEnumerable combination) + => CheckCompatibleSet(combination, out _); + + /// + /// Checks that all s in a combination are compatible with each-other. + /// + /// The combination to check. + /// Any invalid mods in the set. + /// Whether all s in the combination are compatible with each-other. + public static bool CheckCompatibleSet(IEnumerable combination, out List? invalidMods) { - var incompatibleTypes = new HashSet(); - var incomingTypes = new HashSet(); + combination = FlattenMods(combination).ToArray(); + invalidMods = null; - foreach (var mod in combination.SelectMany(FlattenMod)) + foreach (var mod in combination) { - // Add the new mod incompatibilities, checking whether any match the existing mod types. - foreach (var t in mod.IncompatibleMods) + foreach (var type in mod.IncompatibleMods) { - if (incomingTypes.Contains(t)) - return false; - - incompatibleTypes.Add(t); - } - - // Add the new mod types, checking whether any match the incompatible types. - foreach (var t in mod.GetType().EnumerateBaseTypes()) - { - if (incomingTypes.Contains(t)) - return false; - - incomingTypes.Add(t); + foreach (var invalid in combination.Where(m => type.IsInstanceOfType(m))) + { + invalidMods ??= new List(); + invalidMods.Add(invalid); + } } } - return true; + return invalidMods == null; } ///