diff --git a/osu.Game.Tests/Mods/ModSettingsEqualityComparison.cs b/osu.Game.Tests/Mods/ModSettingsEqualityComparison.cs new file mode 100644 index 0000000000..7a5789f01a --- /dev/null +++ b/osu.Game.Tests/Mods/ModSettingsEqualityComparison.cs @@ -0,0 +1,36 @@ +// Copyright (c) ppy Pty Ltd . Licensed under the MIT Licence. +// See the LICENCE file in the repository root for full licence text. + +using NUnit.Framework; +using osu.Game.Online.API; +using osu.Game.Rulesets.Osu.Mods; + +namespace osu.Game.Tests.Mods +{ + [TestFixture] + public class ModSettingsEqualityComparison + { + [Test] + public void Test() + { + var mod1 = new OsuModDoubleTime { SpeedChange = { Value = 1.25 } }; + var mod2 = new OsuModDoubleTime { SpeedChange = { Value = 1.26 } }; + var mod3 = new OsuModDoubleTime { SpeedChange = { Value = 1.26 } }; + var apiMod1 = new APIMod(mod1); + var apiMod2 = new APIMod(mod2); + var apiMod3 = new APIMod(mod3); + + Assert.That(mod1, Is.Not.EqualTo(mod2)); + Assert.That(apiMod1, Is.Not.EqualTo(apiMod2)); + + Assert.That(mod2, Is.EqualTo(mod2)); + Assert.That(apiMod2, Is.EqualTo(apiMod2)); + + Assert.That(mod2, Is.EqualTo(mod3)); + Assert.That(apiMod2, Is.EqualTo(apiMod3)); + + Assert.That(mod3, Is.EqualTo(mod2)); + Assert.That(apiMod3, Is.EqualTo(apiMod2)); + } + } +} diff --git a/osu.Game/Online/API/APIMod.cs b/osu.Game/Online/API/APIMod.cs index bff08b0515..4427c82a8b 100644 --- a/osu.Game/Online/API/APIMod.cs +++ b/osu.Game/Online/API/APIMod.cs @@ -11,11 +11,12 @@ using osu.Framework.Bindables; using osu.Game.Configuration; using osu.Game.Rulesets; using osu.Game.Rulesets.Mods; +using osu.Game.Utils; namespace osu.Game.Online.API { [MessagePackObject] - public class APIMod : IMod + public class APIMod : IMod, IEquatable { [JsonProperty("acronym")] [Key(0)] @@ -63,7 +64,16 @@ namespace osu.Game.Online.API return resultMod; } - public bool Equals(IMod other) => Acronym == other?.Acronym; + public bool Equals(IMod other) => other is APIMod them && Equals(them); + + public bool Equals(APIMod other) + { + if (ReferenceEquals(null, other)) return false; + if (ReferenceEquals(this, other)) return true; + + return Acronym == other.Acronym && + Settings.SequenceEqual(other.Settings, ModSettingsEqualityComparer.Default); + } public override string ToString() { @@ -72,5 +82,20 @@ namespace osu.Game.Online.API return $"{Acronym}"; } + + private class ModSettingsEqualityComparer : IEqualityComparer> + { + public static ModSettingsEqualityComparer Default { get; } = new ModSettingsEqualityComparer(); + + public bool Equals(KeyValuePair x, KeyValuePair y) + { + object xValue = ModUtils.GetSettingUnderlyingValue(x.Value); + object yValue = ModUtils.GetSettingUnderlyingValue(y.Value); + + return x.Key == y.Key && EqualityComparer.Default.Equals(xValue, yValue); + } + + public int GetHashCode(KeyValuePair obj) => HashCode.Combine(obj.Key, ModUtils.GetSettingUnderlyingValue(obj.Value)); + } } } diff --git a/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs b/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs index dd854acc32..81ecc74ddb 100644 --- a/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs +++ b/osu.Game/Online/API/ModSettingsDictionaryFormatter.cs @@ -3,11 +3,10 @@ using System.Buffers; using System.Collections.Generic; -using System.Diagnostics; using System.Text; using MessagePack; using MessagePack.Formatters; -using osu.Framework.Bindables; +using osu.Game.Utils; namespace osu.Game.Online.API { @@ -24,36 +23,7 @@ namespace osu.Game.Online.API var stringBytes = new ReadOnlySequence(Encoding.UTF8.GetBytes(kvp.Key)); writer.WriteString(in stringBytes); - switch (kvp.Value) - { - case Bindable d: - primitiveFormatter.Serialize(ref writer, d.Value, options); - break; - - case Bindable i: - primitiveFormatter.Serialize(ref writer, i.Value, options); - break; - - case Bindable f: - primitiveFormatter.Serialize(ref writer, f.Value, options); - break; - - case Bindable b: - primitiveFormatter.Serialize(ref writer, b.Value, options); - break; - - case IBindable u: - // A mod with unknown (e.g. enum) generic type. - var valueMethod = u.GetType().GetProperty(nameof(IBindable.Value)); - Debug.Assert(valueMethod != null); - primitiveFormatter.Serialize(ref writer, valueMethod.GetValue(u), options); - break; - - default: - // fall back for non-bindable cases. - primitiveFormatter.Serialize(ref writer, kvp.Value, options); - break; - } + primitiveFormatter.Serialize(ref writer, ModUtils.GetSettingUnderlyingValue(kvp.Value), options); } } diff --git a/osu.Game/Rulesets/Mods/Mod.cs b/osu.Game/Rulesets/Mods/Mod.cs index 2a11c92223..832a14ee1e 100644 --- a/osu.Game/Rulesets/Mods/Mod.cs +++ b/osu.Game/Rulesets/Mods/Mod.cs @@ -12,6 +12,7 @@ using osu.Framework.Testing; using osu.Game.Configuration; using osu.Game.IO.Serialization; using osu.Game.Rulesets.UI; +using osu.Game.Utils; namespace osu.Game.Rulesets.Mods { @@ -19,7 +20,7 @@ namespace osu.Game.Rulesets.Mods /// The base class for gameplay modifiers. /// [ExcludeFromDynamicCompile] - public abstract class Mod : IMod, IJsonSerializable + public abstract class Mod : IMod, IEquatable, IJsonSerializable { /// /// The name of this mod. @@ -172,7 +173,19 @@ namespace osu.Game.Rulesets.Mods target.Parse(source); } - public bool Equals(IMod other) => GetType() == other?.GetType(); + public bool Equals(IMod other) => other is Mod them && Equals(them); + + public bool Equals(Mod other) + { + if (ReferenceEquals(null, other)) return false; + if (ReferenceEquals(this, other)) return true; + + return GetType() == other.GetType() && + this.GetSettingsSourceProperties().All(pair => + EqualityComparer.Default.Equals( + ModUtils.GetSettingUnderlyingValue(pair.Item2.GetValue(this)), + ModUtils.GetSettingUnderlyingValue(pair.Item2.GetValue(other)))); + } /// /// Reset all custom settings for this mod back to their defaults. diff --git a/osu.Game/Utils/ModUtils.cs b/osu.Game/Utils/ModUtils.cs index c12b5a9fd4..596880f2e7 100644 --- a/osu.Game/Utils/ModUtils.cs +++ b/osu.Game/Utils/ModUtils.cs @@ -3,8 +3,11 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Linq; +using osu.Framework.Bindables; +using osu.Game.Online.API; using osu.Game.Rulesets.Mods; #nullable enable @@ -129,5 +132,38 @@ namespace osu.Game.Utils else yield return mod; } + + /// + /// Returns the underlying value of the given mod setting object. + /// Used in for serialization and equality comparison purposes. + /// + /// The mod setting. + public static object GetSettingUnderlyingValue(object setting) + { + switch (setting) + { + case Bindable d: + return d.Value; + + case Bindable i: + return i.Value; + + case Bindable f: + return f.Value; + + case Bindable b: + return b.Value; + + case IBindable u: + // A mod with unknown (e.g. enum) generic type. + var valueMethod = u.GetType().GetProperty(nameof(IBindable.Value)); + Debug.Assert(valueMethod != null); + return valueMethod.GetValue(u); + + default: + // fall back for non-bindable cases. + return setting; + } + } } }