1
0
mirror of https://github.com/ppy/osu.git synced 2025-01-19 05:02:53 +08:00

Merge pull request #11328 from frenzibyte/mod-using-reference-equality

Fix mods using reference equality unless cast to `IMod`
This commit is contained in:
Dan Balasescu 2021-04-13 15:36:22 +09:00 committed by GitHub
commit 20e84f14e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 116 additions and 36 deletions

View File

@ -0,0 +1,36 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using NUnit.Framework;
using osu.Game.Online.API;
using osu.Game.Rulesets.Osu.Mods;
namespace osu.Game.Tests.Mods
{
[TestFixture]
public class ModSettingsEqualityComparison
{
[Test]
public void Test()
{
var mod1 = new OsuModDoubleTime { SpeedChange = { Value = 1.25 } };
var mod2 = new OsuModDoubleTime { SpeedChange = { Value = 1.26 } };
var mod3 = new OsuModDoubleTime { SpeedChange = { Value = 1.26 } };
var apiMod1 = new APIMod(mod1);
var apiMod2 = new APIMod(mod2);
var apiMod3 = new APIMod(mod3);
Assert.That(mod1, Is.Not.EqualTo(mod2));
Assert.That(apiMod1, Is.Not.EqualTo(apiMod2));
Assert.That(mod2, Is.EqualTo(mod2));
Assert.That(apiMod2, Is.EqualTo(apiMod2));
Assert.That(mod2, Is.EqualTo(mod3));
Assert.That(apiMod2, Is.EqualTo(apiMod3));
Assert.That(mod3, Is.EqualTo(mod2));
Assert.That(apiMod3, Is.EqualTo(apiMod2));
}
}
}

View File

@ -11,11 +11,12 @@ using osu.Framework.Bindables;
using osu.Game.Configuration;
using osu.Game.Rulesets;
using osu.Game.Rulesets.Mods;
using osu.Game.Utils;
namespace osu.Game.Online.API
{
[MessagePackObject]
public class APIMod : IMod
public class APIMod : IMod, IEquatable<APIMod>
{
[JsonProperty("acronym")]
[Key(0)]
@ -63,7 +64,16 @@ namespace osu.Game.Online.API
return resultMod;
}
public bool Equals(IMod other) => Acronym == other?.Acronym;
public bool Equals(IMod other) => other is APIMod them && Equals(them);
public bool Equals(APIMod other)
{
if (ReferenceEquals(null, other)) return false;
if (ReferenceEquals(this, other)) return true;
return Acronym == other.Acronym &&
Settings.SequenceEqual(other.Settings, ModSettingsEqualityComparer.Default);
}
public override string ToString()
{
@ -72,5 +82,20 @@ namespace osu.Game.Online.API
return $"{Acronym}";
}
private class ModSettingsEqualityComparer : IEqualityComparer<KeyValuePair<string, object>>
{
public static ModSettingsEqualityComparer Default { get; } = new ModSettingsEqualityComparer();
public bool Equals(KeyValuePair<string, object> x, KeyValuePair<string, object> y)
{
object xValue = ModUtils.GetSettingUnderlyingValue(x.Value);
object yValue = ModUtils.GetSettingUnderlyingValue(y.Value);
return x.Key == y.Key && EqualityComparer<object>.Default.Equals(xValue, yValue);
}
public int GetHashCode(KeyValuePair<string, object> obj) => HashCode.Combine(obj.Key, ModUtils.GetSettingUnderlyingValue(obj.Value));
}
}
}

View File

@ -3,11 +3,10 @@
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using MessagePack;
using MessagePack.Formatters;
using osu.Framework.Bindables;
using osu.Game.Utils;
namespace osu.Game.Online.API
{
@ -24,36 +23,7 @@ namespace osu.Game.Online.API
var stringBytes = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(kvp.Key));
writer.WriteString(in stringBytes);
switch (kvp.Value)
{
case Bindable<double> d:
primitiveFormatter.Serialize(ref writer, d.Value, options);
break;
case Bindable<int> i:
primitiveFormatter.Serialize(ref writer, i.Value, options);
break;
case Bindable<float> f:
primitiveFormatter.Serialize(ref writer, f.Value, options);
break;
case Bindable<bool> b:
primitiveFormatter.Serialize(ref writer, b.Value, options);
break;
case IBindable u:
// A mod with unknown (e.g. enum) generic type.
var valueMethod = u.GetType().GetProperty(nameof(IBindable<int>.Value));
Debug.Assert(valueMethod != null);
primitiveFormatter.Serialize(ref writer, valueMethod.GetValue(u), options);
break;
default:
// fall back for non-bindable cases.
primitiveFormatter.Serialize(ref writer, kvp.Value, options);
break;
}
primitiveFormatter.Serialize(ref writer, ModUtils.GetSettingUnderlyingValue(kvp.Value), options);
}
}

View File

@ -12,6 +12,7 @@ using osu.Framework.Testing;
using osu.Game.Configuration;
using osu.Game.IO.Serialization;
using osu.Game.Rulesets.UI;
using osu.Game.Utils;
namespace osu.Game.Rulesets.Mods
{
@ -19,7 +20,7 @@ namespace osu.Game.Rulesets.Mods
/// The base class for gameplay modifiers.
/// </summary>
[ExcludeFromDynamicCompile]
public abstract class Mod : IMod, IJsonSerializable
public abstract class Mod : IMod, IEquatable<Mod>, IJsonSerializable
{
/// <summary>
/// The name of this mod.
@ -172,7 +173,19 @@ namespace osu.Game.Rulesets.Mods
target.Parse(source);
}
public bool Equals(IMod other) => GetType() == other?.GetType();
public bool Equals(IMod other) => other is Mod them && Equals(them);
public bool Equals(Mod other)
{
if (ReferenceEquals(null, other)) return false;
if (ReferenceEquals(this, other)) return true;
return GetType() == other.GetType() &&
this.GetSettingsSourceProperties().All(pair =>
EqualityComparer<object>.Default.Equals(
ModUtils.GetSettingUnderlyingValue(pair.Item2.GetValue(this)),
ModUtils.GetSettingUnderlyingValue(pair.Item2.GetValue(other))));
}
/// <summary>
/// Reset all custom settings for this mod back to their defaults.

View File

@ -3,8 +3,11 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using osu.Framework.Bindables;
using osu.Game.Online.API;
using osu.Game.Rulesets.Mods;
#nullable enable
@ -129,5 +132,38 @@ namespace osu.Game.Utils
else
yield return mod;
}
/// <summary>
/// Returns the underlying value of the given mod setting object.
/// Used in <see cref="APIMod"/> for serialization and equality comparison purposes.
/// </summary>
/// <param name="setting">The mod setting.</param>
public static object GetSettingUnderlyingValue(object setting)
{
switch (setting)
{
case Bindable<double> d:
return d.Value;
case Bindable<int> i:
return i.Value;
case Bindable<float> f:
return f.Value;
case Bindable<bool> b:
return b.Value;
case IBindable u:
// A mod with unknown (e.g. enum) generic type.
var valueMethod = u.GetType().GetProperty(nameof(IBindable<int>.Value));
Debug.Assert(valueMethod != null);
return valueMethod.GetValue(u);
default:
// fall back for non-bindable cases.
return setting;
}
}
}
}