diff --git a/osu.Game.Tests/Mods/ModUtilsTest.cs b/osu.Game.Tests/Mods/ModUtilsTest.cs index e4ded602aa..7dcaabca3d 100644 --- a/osu.Game.Tests/Mods/ModUtilsTest.cs +++ b/osu.Game.Tests/Mods/ModUtilsTest.cs @@ -1,9 +1,12 @@ // Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. // See the LICENCE file in the repository root for full licence text. +using System; +using System.Linq; using Moq; using NUnit.Framework; using osu.Game.Rulesets.Mods; +using osu.Game.Rulesets.Osu.Mods; using osu.Game.Utils; namespace osu.Game.Tests.Mods @@ -99,6 +102,53 @@ namespace osu.Game.Tests.Mods Assert.That(ModUtils.CheckAllowed(new[] { mod.Object }, new[] { typeof(Mod) }), Is.False); } + private static readonly object[] invalid_mod_test_scenarios = + { + // incompatible pair. + new object[] + { + new Mod[] { new OsuModDoubleTime(), new OsuModHalfTime() }, + new[] { typeof(OsuModDoubleTime), typeof(OsuModHalfTime) } + }, + // incompatible pair with derived class. + new object[] + { + new Mod[] { new OsuModNightcore(), new OsuModHalfTime() }, + new[] { typeof(OsuModNightcore), typeof(OsuModHalfTime) } + }, + // system mod. + new object[] + { + new Mod[] { new OsuModDoubleTime(), new OsuModTouchDevice() }, + new[] { typeof(OsuModTouchDevice) } + }, + // multi mod. + new object[] + { + new Mod[] { new MultiMod(new OsuModHalfTime()), new OsuModHalfTime() }, + new[] { typeof(MultiMod) } + }, + // valid pair. + new object[] + { + new Mod[] { new OsuModDoubleTime(), new OsuModHardRock() }, + null + } + }; + + [TestCaseSource(nameof(invalid_mod_test_scenarios))] + public void TestInvalidModScenarios(Mod[] inputMods, Type[] expectedInvalid) + { + bool isValid = ModUtils.CheckValidForGameplay(inputMods, out var invalid); + + Assert.That(isValid, Is.EqualTo(expectedInvalid == null)); + + if (isValid) + Assert.IsNull(invalid); + else + Assert.That(invalid?.Select(t => t.GetType()), Is.EquivalentTo(expectedInvalid)); + } + public abstract class CustomMod1 : Mod { } diff --git a/osu.Game/OsuGame.cs b/osu.Game/OsuGame.cs index 5acd6bc73d..a00cd5e6a0 100644 --- a/osu.Game/OsuGame.cs +++ b/osu.Game/OsuGame.cs @@ -468,6 +468,12 @@ namespace osu.Game private void modsChanged(ValueChangedEvent> mods) { updateModDefaults(); + + if (!ModUtils.CheckValidForGameplay(mods.NewValue, out var invalid)) + { + // ensure we always have a valid set of mods. + SelectedMods.Value = mods.NewValue.Except(invalid).ToArray(); + } } private void updateModDefaults() diff --git a/osu.Game/Utils/ModUtils.cs b/osu.Game/Utils/ModUtils.cs index 8ac5bde65a..c12b5a9fd4 100644 --- a/osu.Game/Utils/ModUtils.cs +++ b/osu.Game/Utils/ModUtils.cs @@ -83,6 +83,30 @@ namespace osu.Game.Utils .All(m => allowedSet.Contains(m.GetType())); } + /// + /// Check the provided combination of mods are valid for a local gameplay session. + /// + /// The mods to check. + /// Invalid mods, if any were found. Can be null if all mods were valid. + /// Whether the input mods were all valid. If false, will contain all invalid entries. + public static bool CheckValidForGameplay(IEnumerable mods, out List? invalidMods) + { + mods = mods.ToArray(); + + CheckCompatibleSet(mods, out invalidMods); + + foreach (var mod in mods) + { + if (mod.Type == ModType.System || !mod.HasImplementation || mod is MultiMod) + { + invalidMods ??= new List(); + invalidMods.Add(mod); + } + } + + return invalidMods == null; + } + /// /// Flattens a set of s, returning a new set with all s removed. ///