diff --git a/osu.Game.Tests/NonVisual/DifficultyAdjustmentModCombinationsTest.cs b/osu.Game.Tests/NonVisual/DifficultyAdjustmentModCombinationsTest.cs index 760a033aff..de0397dc84 100644 --- a/osu.Game.Tests/NonVisual/DifficultyAdjustmentModCombinationsTest.cs +++ b/osu.Game.Tests/NonVisual/DifficultyAdjustmentModCombinationsTest.cs @@ -94,6 +94,38 @@ namespace osu.Game.Tests.NonVisual Assert.IsTrue(combinations[2] is ModIncompatibleWithAofA); } + [Test] + public void TestMultiMod1() + { + var combinations = new TestLegacyDifficultyCalculator(new ModA(), new MultiMod(new ModB(), new ModC())).CreateDifficultyAdjustmentModCombinations(); + + Assert.AreEqual(4, combinations.Length); + Assert.IsTrue(combinations[0] is ModNoMod); + Assert.IsTrue(combinations[1] is ModA); + Assert.IsTrue(combinations[2] is MultiMod); + Assert.IsTrue(combinations[3] is MultiMod); + + Assert.IsTrue(((MultiMod)combinations[2]).Mods[0] is ModA); + Assert.IsTrue(((MultiMod)combinations[2]).Mods[1] is ModB); + Assert.IsTrue(((MultiMod)combinations[2]).Mods[2] is ModC); + Assert.IsTrue(((MultiMod)combinations[3]).Mods[0] is ModB); + Assert.IsTrue(((MultiMod)combinations[3]).Mods[1] is ModC); + } + + [Test] + public void TestMultiMod2() + { + var combinations = new TestLegacyDifficultyCalculator(new ModA(), new MultiMod(new ModB(), new ModIncompatibleWithA())).CreateDifficultyAdjustmentModCombinations(); + + Assert.AreEqual(3, combinations.Length); + Assert.IsTrue(combinations[0] is ModNoMod); + Assert.IsTrue(combinations[1] is ModA); + Assert.IsTrue(combinations[2] is MultiMod); + + Assert.IsTrue(((MultiMod)combinations[2]).Mods[0] is ModB); + Assert.IsTrue(((MultiMod)combinations[2]).Mods[1] is ModIncompatibleWithA); + } + private class ModA : Mod { public override string Name => nameof(ModA); @@ -112,6 +144,13 @@ namespace osu.Game.Tests.NonVisual public override Type[] IncompatibleMods => new[] { typeof(ModIncompatibleWithAAndB) }; } + private class ModC : Mod + { + public override string Name => nameof(ModC); + public override string Acronym => nameof(ModC); + public override double ScoreMultiplier => 1; + } + private class ModIncompatibleWithA : Mod { public override string Name => $"Incompatible With {nameof(ModA)}"; diff --git a/osu.Game/Rulesets/Difficulty/DifficultyCalculator.cs b/osu.Game/Rulesets/Difficulty/DifficultyCalculator.cs index 1902de5bda..9989c750ee 100644 --- a/osu.Game/Rulesets/Difficulty/DifficultyCalculator.cs +++ b/osu.Game/Rulesets/Difficulty/DifficultyCalculator.cs @@ -107,7 +107,7 @@ namespace osu.Game.Rulesets.Difficulty { return createDifficultyAdjustmentModCombinations(Array.Empty(), DifficultyAdjustmentMods).ToArray(); - IEnumerable createDifficultyAdjustmentModCombinations(IEnumerable currentSet, Mod[] adjustmentSet, int currentSetCount = 0, int adjustmentSetStart = 0) + static IEnumerable createDifficultyAdjustmentModCombinations(IEnumerable currentSet, Mod[] adjustmentSet, int currentSetCount = 0, int adjustmentSetStart = 0) { switch (currentSetCount) { @@ -133,13 +133,36 @@ namespace osu.Game.Rulesets.Difficulty for (int i = adjustmentSetStart; i < adjustmentSet.Length; i++) { var adjustmentMod = adjustmentSet[i]; - if (currentSet.Any(c => c.IncompatibleMods.Any(m => m.IsInstanceOfType(adjustmentMod)))) - continue; - foreach (var combo in createDifficultyAdjustmentModCombinations(currentSet.Append(adjustmentMod), adjustmentSet, currentSetCount + 1, i + 1)) + if (currentSet.Any(c => c.IncompatibleMods.Any(m => m.IsInstanceOfType(adjustmentMod)) + || adjustmentMod.IncompatibleMods.Any(m => m.IsInstanceOfType(c)))) + { + continue; + } + + // Append the new mod. + int newSetCount = currentSetCount; + var newSet = append(currentSet, adjustmentMod, ref newSetCount); + + foreach (var combo in createDifficultyAdjustmentModCombinations(newSet, adjustmentSet, newSetCount, i + 1)) yield return combo; } } + + // Appends a mod to an existing enumerable, returning the result. Recurses for MultiMod. + static IEnumerable append(IEnumerable existing, Mod mod, ref int count) + { + if (mod is MultiMod multi) + { + foreach (var nested in multi.Mods) + existing = append(existing, nested, ref count); + + return existing; + } + + count++; + return existing.Append(mod); + } } ///