mirror of
https://github.com/ppy/osu.git
synced 2024-12-05 09:42:54 +08:00
Implement quartic curve fitting miss penalty
This commit is contained in:
parent
fd53423121
commit
ef0797ad30
@ -19,6 +19,12 @@ namespace osu.Game.Rulesets.Osu.Difficulty
|
||||
[JsonProperty("aim_difficulty")]
|
||||
public double AimDifficulty { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// The difficulty corresponding to the aim skill.
|
||||
/// </summary>
|
||||
[JsonProperty("aim_penalty_constants")]
|
||||
public (double, double, double) AimPenaltyConstants { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// The difficulty corresponding to the speed skill.
|
||||
/// </summary>
|
||||
|
@ -37,6 +37,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty
|
||||
return new OsuDifficultyAttributes { Mods = mods };
|
||||
|
||||
double aimRating = Math.Sqrt(skills[0].DifficultyValue()) * difficulty_multiplier;
|
||||
(double, double, double) aimPenaltyConstants = ((Aim)skills[0]).GetMissCountCoefficients();
|
||||
double aimRatingNoSliders = Math.Sqrt(skills[1].DifficultyValue()) * difficulty_multiplier;
|
||||
double speedRating = Math.Sqrt(skills[2].DifficultyValue()) * difficulty_multiplier;
|
||||
double speedNotes = ((Speed)skills[2]).RelevantNoteCount();
|
||||
@ -97,6 +98,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty
|
||||
StarRating = starRating,
|
||||
Mods = mods,
|
||||
AimDifficulty = aimRating,
|
||||
AimPenaltyConstants = aimPenaltyConstants,
|
||||
SpeedDifficulty = speedRating,
|
||||
SpeedNoteCount = speedNotes,
|
||||
FlashlightDifficulty = flashlightRating,
|
||||
|
@ -5,6 +5,7 @@ using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using osu.Game.Rulesets.Difficulty;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Utils;
|
||||
using osu.Game.Rulesets.Osu.Mods;
|
||||
using osu.Game.Rulesets.Scoring;
|
||||
using osu.Game.Scoring;
|
||||
@ -93,9 +94,16 @@ namespace osu.Game.Rulesets.Osu.Difficulty
|
||||
|
||||
// Penalize misses by assessing # of misses relative to the total # of objects. Default a 3% reduction for any # of misses.
|
||||
if (effectiveMissCount > 0)
|
||||
aimValue *= 0.97 * Math.Pow(1 - Math.Pow(effectiveMissCount / totalHits, 0.775), effectiveMissCount);
|
||||
{
|
||||
double a = attributes.AimPenaltyConstants.Item1;
|
||||
double b = attributes.AimPenaltyConstants.Item2;
|
||||
double c = attributes.AimPenaltyConstants.Item3;
|
||||
double d = Math.Log(totalHits + 1) - a - b - c;
|
||||
|
||||
aimValue *= getComboScalingFactor(attributes);
|
||||
double penalty = Math.Pow(1 - RootFinding.FindRootExpand(x => Math.Exp(a * x * x * x * x + b * x * x * x + c * x * x + d * x) - 1 - effectiveMissCount, 0, 1), 1.5);
|
||||
|
||||
aimValue *= penalty;
|
||||
}
|
||||
|
||||
double approachRateFactor = 0.0;
|
||||
if (attributes.ApproachRate > 10.33)
|
||||
|
@ -46,5 +46,37 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills
|
||||
|
||||
return currentStrain;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// The coefficients of a quartic fitted to the miss counts at each skill level.
|
||||
/// </summary>
|
||||
/// <returns>The coefficients for ax^4+bx^3+cx^2. The 4th coefficient for dx^1 can be deduced from the first 3 in the performance calculator.</returns>
|
||||
public (double, double, double) GetMissCountCoefficients()
|
||||
{
|
||||
const int count = 21;
|
||||
const double penalty_per_misscount = 1.0 / (count - 1);
|
||||
|
||||
double fcSkill = DifficultyValue();
|
||||
|
||||
double[] misscounts = new double[count];
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
if (i == 0)
|
||||
{
|
||||
misscounts[i] = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
double penalizedSkill = fcSkill - fcSkill * penalty_per_misscount * i;
|
||||
|
||||
// Save misscounts as log form to give higher weight to lower values. Add 1 so that the lowest misscounts remain above 0.
|
||||
misscounts[i] = Math.Log(GetMissCountAtSkill(penalizedSkill) + 1);
|
||||
}
|
||||
|
||||
double[] constants = FitMissCountPoints.GetPolynomialCoefficients(misscounts);
|
||||
|
||||
return (constants[0], constants[1], constants[2]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -22,8 +22,6 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills
|
||||
/// A higher value rewards short, high difficulty sections, whereas a lower value rewards consistent, lower difficulty.
|
||||
protected abstract double FcProbability { get; }
|
||||
|
||||
private const int bin_count = 32;
|
||||
|
||||
private readonly List<double> difficulties = new List<double>();
|
||||
|
||||
/// <summary>
|
||||
@ -43,7 +41,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills
|
||||
double maxDiff = difficulties.Max();
|
||||
if (maxDiff <= 1e-10) return 0;
|
||||
|
||||
var bins = Bin.CreateBins(difficulties, bin_count);
|
||||
var bins = Bin.CreateBins(difficulties);
|
||||
|
||||
const double lower_bound = 0;
|
||||
double upperBoundEstimate = 3.0 * maxDiff;
|
||||
@ -93,7 +91,34 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills
|
||||
if (difficulties.Count == 0)
|
||||
return 0;
|
||||
|
||||
return difficulties.Count < 2 * bin_count ? difficultyValueExact() : difficultyValueBinned();
|
||||
return difficulties.Count < 64 ? difficultyValueExact() : difficultyValueBinned();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Find the lowest misscount that a player with the provided <paramref name="skill"/> would have a 2% chance of achieving.
|
||||
/// </summary>
|
||||
public double GetMissCountAtSkill(double skill)
|
||||
{
|
||||
double maxDiff = difficulties.Max();
|
||||
|
||||
if (maxDiff == 0)
|
||||
return 0;
|
||||
if (skill <= 0)
|
||||
return difficulties.Count;
|
||||
|
||||
PoissonBinomial poiBin;
|
||||
|
||||
if (difficulties.Count > 64)
|
||||
{
|
||||
var bins = Bin.CreateBins(difficulties);
|
||||
poiBin = new PoissonBinomial(bins, skill, HitProbability);
|
||||
}
|
||||
else
|
||||
{
|
||||
poiBin = new PoissonBinomial(difficulties, skill, HitProbability);
|
||||
}
|
||||
|
||||
return Math.Max(0, RootFinding.FindRootExpand(x => poiBin.CDF(x) - FcProbability, -50, 1000, accuracy: 1e-4));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -12,25 +12,27 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils
|
||||
public double Difficulty;
|
||||
public double Count;
|
||||
|
||||
private const int bin_count = 32;
|
||||
|
||||
/// <summary>
|
||||
/// Create an array of equally spaced bins. Count is linearly interpolated into each bin.
|
||||
/// For example, if we have bins with values [1,2,3,4,5] and want to insert the value 3.2,
|
||||
/// we will add 0.8 to the count of 3's and 0.2 to the count of 4's
|
||||
/// </summary>
|
||||
public static Bin[] CreateBins(List<double> difficulties, int binCount)
|
||||
public static Bin[] CreateBins(List<double> difficulties)
|
||||
{
|
||||
double maxDifficulty = difficulties.Max();
|
||||
|
||||
var bins = new Bin[binCount];
|
||||
var bins = new Bin[bin_count];
|
||||
|
||||
for (int i = 0; i < binCount; i++)
|
||||
for (int i = 0; i < bin_count; i++)
|
||||
{
|
||||
bins[i].Difficulty = maxDifficulty * (i + 1) / binCount;
|
||||
bins[i].Difficulty = maxDifficulty * (i + 1) / bin_count;
|
||||
}
|
||||
|
||||
foreach (double d in difficulties)
|
||||
{
|
||||
double binIndex = binCount * (d / maxDifficulty) - 1;
|
||||
double binIndex = bin_count * (d / maxDifficulty) - 1;
|
||||
|
||||
int lowerBound = (int)Math.Floor(binIndex);
|
||||
double t = binIndex - lowerBound;
|
||||
@ -45,7 +47,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils
|
||||
int upperBound = lowerBound + 1;
|
||||
|
||||
// this can be == bin_count for the maximum difficulty object, in which case t will be 0 anyway
|
||||
if (upperBound < binCount)
|
||||
if (upperBound < bin_count)
|
||||
{
|
||||
bins[upperBound].Count += t;
|
||||
}
|
||||
|
49
osu.Game.Rulesets.Osu/Difficulty/Utils/FitMissCountPoints.cs
Normal file
49
osu.Game.Rulesets.Osu/Difficulty/Utils/FitMissCountPoints.cs
Normal file
@ -0,0 +1,49 @@
|
||||
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
|
||||
// See the LICENCE file in the repository root for full licence text.
|
||||
|
||||
using System.Linq;
|
||||
|
||||
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
|
||||
{
|
||||
public class FitMissCountPoints
|
||||
{
|
||||
// A few operations are precomputed on the Vandermonde matrix of skill values (0, 0.05, 0.1, ..., 0.95, 1).
|
||||
// This is to make regression simpler and less resource heavy. https://discord.com/channels/546120878908506119/1203154944492830791/1232333184306512002
|
||||
private static double[][] precomputedOperationsMatrix => new[]
|
||||
{
|
||||
new[] { 0.0, -6.99428, -9.87548, -9.76922, -7.66867, -4.43461, -0.795376, 2.65313, 5.4474, 7.2564, 7.88146, 7.2564, 5.4474, 2.65313, -0.795376, -4.43461, -7.66867, -9.76922, -9.87548, -6.99428, 0.0 },
|
||||
new[] { 0.0, 13.0907, 18.2388, 17.6639, 13.3211, 6.90022, -0.173479, -6.73969, -11.9029, -15.0326, -15.7629, -13.993, -9.88668, -3.87281, 3.35498, 10.8382, 17.3536, 21.4129, 21.2632, 14.8864, 0.0 },
|
||||
new[] { 0.0, -7.21754, -9.85841, -9.24217, -6.5276, -2.71265, 1.36553, 5.03057, 7.76692, 9.21984, 9.19538, 7.66039, 4.74253, 0.730255, -3.92717, -8.61967, -12.5764, -14.8657, -14.395, -9.91114, 0.0 }
|
||||
};
|
||||
|
||||
public static double[] GetPolynomialCoefficients(double[] missCounts)
|
||||
{
|
||||
double[] adjustedMissCounts = missCounts;
|
||||
|
||||
// The polynomial will pass through the point (1, maxMisscount).
|
||||
double maxMissCount = missCounts.Max();
|
||||
|
||||
for (int i = 0; i <= 20; i++)
|
||||
{
|
||||
adjustedMissCounts[i] -= maxMissCount * i / 20;
|
||||
}
|
||||
|
||||
// The precomputed matrix assumes the misscounts go in order of greatest to least.
|
||||
// Temporary fix.
|
||||
adjustedMissCounts = adjustedMissCounts.Reverse().ToArray();
|
||||
|
||||
double[] coefficients = new double[3];
|
||||
|
||||
// Now we dot product the adjusted misscounts with the precomputed matrix.
|
||||
for (int row = 0; row < precomputedOperationsMatrix.Length; row++)
|
||||
{
|
||||
for (int column = 0; column < precomputedOperationsMatrix[row].Length; column++)
|
||||
{
|
||||
coefficients[row] += precomputedOperationsMatrix[row][column] * adjustedMissCounts[column];
|
||||
}
|
||||
}
|
||||
|
||||
return coefficients;
|
||||
}
|
||||
}
|
||||
}
|
119
osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs
Normal file
119
osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs
Normal file
@ -0,0 +1,119 @@
|
||||
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
|
||||
// See the LICENCE file in the repository root for full licence text.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
|
||||
{
|
||||
/// <summary>
|
||||
/// Approximation of the Poisson binomial distribution:
|
||||
/// https://en.wikipedia.org/wiki/Poisson_binomial_distribution
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// <para>
|
||||
/// For the approximation method, see "Refined Normal Approximation (RNA)" from:
|
||||
/// Hong, Y. (2013). On computing the distribution function for the Poisson binomial distribution. Computational Statistics and Data Analysis, Vol. 59, pp. 41-51.
|
||||
/// (https://www.researchgate.net/publication/257017356_On_computing_the_distribution_function_for_the_Poisson_binomial_distribution)
|
||||
/// </para>
|
||||
/// <para>
|
||||
/// This has been verified against a reference implementation provided by the authors in the R package "poibin",
|
||||
/// which can be viewed here:
|
||||
/// https://rdrr.io/cran/poibin/man/poibin-package.html
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
public class PoissonBinomial
|
||||
{
|
||||
/// <summary>
|
||||
/// The expected value of the distribution.
|
||||
/// </summary>
|
||||
private readonly double mu;
|
||||
|
||||
/// <summary>
|
||||
/// The standard deviation of the distribution.
|
||||
/// </summary>
|
||||
private readonly double sigma;
|
||||
|
||||
/// <summary>
|
||||
/// The gamma factor from equation (11) in the cited paper, pre-divided by 6 to save on re-computation.
|
||||
/// </summary>
|
||||
private readonly double v;
|
||||
|
||||
/// <summary>
|
||||
/// Creates a Poisson binomial distribution based on N trials with the provided difficulties, skill, and method for getting the miss probabilities.
|
||||
/// </summary>
|
||||
/// <param name="difficulties">The list of difficulties in the map.</param>
|
||||
/// <param name="skill">The skill level to get the miss probabilities with.</param>
|
||||
/// <param name="hitProbability">Converts difficulties and skill to miss probabilities.</param>
|
||||
public PoissonBinomial(IList<double> difficulties, double skill, Func<double, double, double> hitProbability)
|
||||
{
|
||||
double variance = 0;
|
||||
double gamma = 0;
|
||||
|
||||
foreach (double d in difficulties)
|
||||
{
|
||||
double p = 1 - hitProbability(skill, d);
|
||||
|
||||
mu += p;
|
||||
variance += p * (1 - p);
|
||||
gamma += p * (1 - p) * (1 - 2 * p);
|
||||
}
|
||||
|
||||
sigma = Math.Sqrt(variance);
|
||||
|
||||
v = gamma / (6 * Math.Pow(sigma, 3));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a Poisson binomial distribution based on N trials with the provided bins of difficulties, skill, and method for getting the miss probabilities.
|
||||
/// </summary>
|
||||
/// <param name="bins">The bins of difficulties in the map.</param>
|
||||
/// <param name="skill">The skill level to get the miss probabilities with.</param>
|
||||
/// /// <param name="hitProbability">Converts difficulties and skill to miss probabilities.</param>
|
||||
public PoissonBinomial(Bin[] bins, double skill, Func<double, double, double> hitProbability)
|
||||
{
|
||||
double variance = 0;
|
||||
double gamma = 0;
|
||||
|
||||
foreach (Bin bin in bins)
|
||||
{
|
||||
double p = 1 - hitProbability(skill, bin.Difficulty);
|
||||
|
||||
mu += p * bin.Count;
|
||||
variance += p * (1 - p) * bin.Count;
|
||||
gamma += p * (1 - p) * (1 - 2 * p) * bin.Count;
|
||||
}
|
||||
|
||||
sigma = Math.Sqrt(variance);
|
||||
|
||||
v = gamma / (6 * Math.Pow(sigma, 3));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Computes the value of the cumulative distribution function for this Poisson binomial distribution.
|
||||
/// </summary>
|
||||
/// <param name="count">
|
||||
/// The argument of the CDF to sample the distribution for.
|
||||
/// In the discrete case (when it is a whole number), this corresponds to the number
|
||||
/// of successful Bernoulli trials to query the CDF for.
|
||||
/// </param>
|
||||
/// <returns>
|
||||
/// The value of the CDF at <paramref name="count"/>.
|
||||
/// In the discrete case this corresponds to the probability that at most <paramref name="count"/>
|
||||
/// Bernoulli trials ended in a success.
|
||||
/// </returns>
|
||||
// ReSharper disable once InconsistentNaming
|
||||
public double CDF(double count)
|
||||
{
|
||||
if (sigma == 0)
|
||||
return 1;
|
||||
|
||||
double k = (count + 0.5 - mu) / sigma;
|
||||
|
||||
// see equation (14) of the cited paper
|
||||
double result = SpecialFunctions.NormalCdf(0, 1, k) + v * (1 - k * k) * SpecialFunctions.NormalPdf(0, 1, k);
|
||||
|
||||
return Math.Clamp(result, 0, 1);
|
||||
}
|
||||
}
|
||||
}
|
@ -18,6 +18,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils
|
||||
{
|
||||
public class SpecialFunctions
|
||||
{
|
||||
private const double sqrt2 = 1.4142135623730950488016887242096980785696718753769d;
|
||||
private const double sqrt2_pi = 2.5066282746310005024157652848110452530069867406099d;
|
||||
|
||||
/// <summary>
|
||||
@ -690,5 +691,42 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Computes the probability density of the distribution (PDF) at x, i.e. ∂P(X ≤ x)/∂x.
|
||||
/// </summary>
|
||||
/// <param name="mean">The mean (μ) of the normal distribution.</param>
|
||||
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ ≥ 0.</param>
|
||||
/// <param name="x">The location at which to compute the density.</param>
|
||||
/// <returns>the density at <paramref name="x"/>.</returns>
|
||||
/// <remarks>MATLAB: normpdf</remarks>
|
||||
public static double NormalPdf(double mean, double stddev, double x)
|
||||
{
|
||||
if (stddev < 0.0)
|
||||
{
|
||||
throw new ArgumentException("Invalid parametrization for the distribution.");
|
||||
}
|
||||
|
||||
double d = (x - mean) / stddev;
|
||||
return Math.Exp(-0.5 * d * d) / (sqrt2_pi * stddev);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Computes the cumulative distribution (CDF) of the distribution at x, i.e. P(X ≤ x).
|
||||
/// </summary>
|
||||
/// <param name="x">The location at which to compute the cumulative distribution function.</param>
|
||||
/// <param name="mean">The mean (μ) of the normal distribution.</param>
|
||||
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ ≥ 0.</param>
|
||||
/// <returns>the cumulative distribution at location <paramref name="x"/>.</returns>
|
||||
/// <remarks>MATLAB: normcdf</remarks>
|
||||
public static double NormalCdf(double mean, double stddev, double x)
|
||||
{
|
||||
if (stddev < 0.0)
|
||||
{
|
||||
throw new ArgumentException("Invalid parametrization for the distribution.");
|
||||
}
|
||||
|
||||
return 0.5 * Erfc((mean - x) / (stddev * sqrt2));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user