1
0
mirror of https://github.com/ppy/osu.git synced 2024-12-05 09:42:54 +08:00

Implement quartic curve fitting miss penalty

This commit is contained in:
Nathen 2024-04-23 19:20:36 -04:00
parent fd53423121
commit ef0797ad30
9 changed files with 293 additions and 12 deletions

View File

@ -19,6 +19,12 @@ namespace osu.Game.Rulesets.Osu.Difficulty
[JsonProperty("aim_difficulty")]
public double AimDifficulty { get; set; }
/// <summary>
/// The difficulty corresponding to the aim skill.
/// </summary>
[JsonProperty("aim_penalty_constants")]
public (double, double, double) AimPenaltyConstants { get; set; }
/// <summary>
/// The difficulty corresponding to the speed skill.
/// </summary>

View File

@ -37,6 +37,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty
return new OsuDifficultyAttributes { Mods = mods };
double aimRating = Math.Sqrt(skills[0].DifficultyValue()) * difficulty_multiplier;
(double, double, double) aimPenaltyConstants = ((Aim)skills[0]).GetMissCountCoefficients();
double aimRatingNoSliders = Math.Sqrt(skills[1].DifficultyValue()) * difficulty_multiplier;
double speedRating = Math.Sqrt(skills[2].DifficultyValue()) * difficulty_multiplier;
double speedNotes = ((Speed)skills[2]).RelevantNoteCount();
@ -97,6 +98,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty
StarRating = starRating,
Mods = mods,
AimDifficulty = aimRating,
AimPenaltyConstants = aimPenaltyConstants,
SpeedDifficulty = speedRating,
SpeedNoteCount = speedNotes,
FlashlightDifficulty = flashlightRating,

View File

@ -5,6 +5,7 @@ using System;
using System.Collections.Generic;
using System.Linq;
using osu.Game.Rulesets.Difficulty;
using osu.Game.Rulesets.Osu.Difficulty.Utils;
using osu.Game.Rulesets.Osu.Mods;
using osu.Game.Rulesets.Scoring;
using osu.Game.Scoring;
@ -93,9 +94,16 @@ namespace osu.Game.Rulesets.Osu.Difficulty
// Penalize misses by assessing # of misses relative to the total # of objects. Default a 3% reduction for any # of misses.
if (effectiveMissCount > 0)
aimValue *= 0.97 * Math.Pow(1 - Math.Pow(effectiveMissCount / totalHits, 0.775), effectiveMissCount);
{
double a = attributes.AimPenaltyConstants.Item1;
double b = attributes.AimPenaltyConstants.Item2;
double c = attributes.AimPenaltyConstants.Item3;
double d = Math.Log(totalHits + 1) - a - b - c;
aimValue *= getComboScalingFactor(attributes);
double penalty = Math.Pow(1 - RootFinding.FindRootExpand(x => Math.Exp(a * x * x * x * x + b * x * x * x + c * x * x + d * x) - 1 - effectiveMissCount, 0, 1), 1.5);
aimValue *= penalty;
}
double approachRateFactor = 0.0;
if (attributes.ApproachRate > 10.33)

View File

@ -46,5 +46,37 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills
return currentStrain;
}
/// <summary>
/// The coefficients of a quartic fitted to the miss counts at each skill level.
/// </summary>
/// <returns>The coefficients for ax^4+bx^3+cx^2. The 4th coefficient for dx^1 can be deduced from the first 3 in the performance calculator.</returns>
public (double, double, double) GetMissCountCoefficients()
{
const int count = 21;
const double penalty_per_misscount = 1.0 / (count - 1);
double fcSkill = DifficultyValue();
double[] misscounts = new double[count];
for (int i = 0; i < count; i++)
{
if (i == 0)
{
misscounts[i] = 0;
continue;
}
double penalizedSkill = fcSkill - fcSkill * penalty_per_misscount * i;
// Save misscounts as log form to give higher weight to lower values. Add 1 so that the lowest misscounts remain above 0.
misscounts[i] = Math.Log(GetMissCountAtSkill(penalizedSkill) + 1);
}
double[] constants = FitMissCountPoints.GetPolynomialCoefficients(misscounts);
return (constants[0], constants[1], constants[2]);
}
}
}

View File

@ -22,8 +22,6 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills
/// A higher value rewards short, high difficulty sections, whereas a lower value rewards consistent, lower difficulty.
protected abstract double FcProbability { get; }
private const int bin_count = 32;
private readonly List<double> difficulties = new List<double>();
/// <summary>
@ -43,7 +41,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills
double maxDiff = difficulties.Max();
if (maxDiff <= 1e-10) return 0;
var bins = Bin.CreateBins(difficulties, bin_count);
var bins = Bin.CreateBins(difficulties);
const double lower_bound = 0;
double upperBoundEstimate = 3.0 * maxDiff;
@ -93,7 +91,34 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills
if (difficulties.Count == 0)
return 0;
return difficulties.Count < 2 * bin_count ? difficultyValueExact() : difficultyValueBinned();
return difficulties.Count < 64 ? difficultyValueExact() : difficultyValueBinned();
}
/// <summary>
/// Find the lowest misscount that a player with the provided <paramref name="skill"/> would have a 2% chance of achieving.
/// </summary>
public double GetMissCountAtSkill(double skill)
{
double maxDiff = difficulties.Max();
if (maxDiff == 0)
return 0;
if (skill <= 0)
return difficulties.Count;
PoissonBinomial poiBin;
if (difficulties.Count > 64)
{
var bins = Bin.CreateBins(difficulties);
poiBin = new PoissonBinomial(bins, skill, HitProbability);
}
else
{
poiBin = new PoissonBinomial(difficulties, skill, HitProbability);
}
return Math.Max(0, RootFinding.FindRootExpand(x => poiBin.CDF(x) - FcProbability, -50, 1000, accuracy: 1e-4));
}
}
}

View File

@ -12,25 +12,27 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils
public double Difficulty;
public double Count;
private const int bin_count = 32;
/// <summary>
/// Create an array of equally spaced bins. Count is linearly interpolated into each bin.
/// For example, if we have bins with values [1,2,3,4,5] and want to insert the value 3.2,
/// we will add 0.8 to the count of 3's and 0.2 to the count of 4's
/// </summary>
public static Bin[] CreateBins(List<double> difficulties, int binCount)
public static Bin[] CreateBins(List<double> difficulties)
{
double maxDifficulty = difficulties.Max();
var bins = new Bin[binCount];
var bins = new Bin[bin_count];
for (int i = 0; i < binCount; i++)
for (int i = 0; i < bin_count; i++)
{
bins[i].Difficulty = maxDifficulty * (i + 1) / binCount;
bins[i].Difficulty = maxDifficulty * (i + 1) / bin_count;
}
foreach (double d in difficulties)
{
double binIndex = binCount * (d / maxDifficulty) - 1;
double binIndex = bin_count * (d / maxDifficulty) - 1;
int lowerBound = (int)Math.Floor(binIndex);
double t = binIndex - lowerBound;
@ -45,7 +47,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils
int upperBound = lowerBound + 1;
// this can be == bin_count for the maximum difficulty object, in which case t will be 0 anyway
if (upperBound < binCount)
if (upperBound < bin_count)
{
bins[upperBound].Count += t;
}

View File

@ -0,0 +1,49 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System.Linq;
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
{
public class FitMissCountPoints
{
// A few operations are precomputed on the Vandermonde matrix of skill values (0, 0.05, 0.1, ..., 0.95, 1).
// This is to make regression simpler and less resource heavy. https://discord.com/channels/546120878908506119/1203154944492830791/1232333184306512002
private static double[][] precomputedOperationsMatrix => new[]
{
new[] { 0.0, -6.99428, -9.87548, -9.76922, -7.66867, -4.43461, -0.795376, 2.65313, 5.4474, 7.2564, 7.88146, 7.2564, 5.4474, 2.65313, -0.795376, -4.43461, -7.66867, -9.76922, -9.87548, -6.99428, 0.0 },
new[] { 0.0, 13.0907, 18.2388, 17.6639, 13.3211, 6.90022, -0.173479, -6.73969, -11.9029, -15.0326, -15.7629, -13.993, -9.88668, -3.87281, 3.35498, 10.8382, 17.3536, 21.4129, 21.2632, 14.8864, 0.0 },
new[] { 0.0, -7.21754, -9.85841, -9.24217, -6.5276, -2.71265, 1.36553, 5.03057, 7.76692, 9.21984, 9.19538, 7.66039, 4.74253, 0.730255, -3.92717, -8.61967, -12.5764, -14.8657, -14.395, -9.91114, 0.0 }
};
public static double[] GetPolynomialCoefficients(double[] missCounts)
{
double[] adjustedMissCounts = missCounts;
// The polynomial will pass through the point (1, maxMisscount).
double maxMissCount = missCounts.Max();
for (int i = 0; i <= 20; i++)
{
adjustedMissCounts[i] -= maxMissCount * i / 20;
}
// The precomputed matrix assumes the misscounts go in order of greatest to least.
// Temporary fix.
adjustedMissCounts = adjustedMissCounts.Reverse().ToArray();
double[] coefficients = new double[3];
// Now we dot product the adjusted misscounts with the precomputed matrix.
for (int row = 0; row < precomputedOperationsMatrix.Length; row++)
{
for (int column = 0; column < precomputedOperationsMatrix[row].Length; column++)
{
coefficients[row] += precomputedOperationsMatrix[row][column] * adjustedMissCounts[column];
}
}
return coefficients;
}
}
}

View File

@ -0,0 +1,119 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System;
using System.Collections.Generic;
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
{
/// <summary>
/// Approximation of the Poisson binomial distribution:
/// https://en.wikipedia.org/wiki/Poisson_binomial_distribution
/// </summary>
/// <remarks>
/// <para>
/// For the approximation method, see "Refined Normal Approximation (RNA)" from:
/// Hong, Y. (2013). On computing the distribution function for the Poisson binomial distribution. Computational Statistics and Data Analysis, Vol. 59, pp. 41-51.
/// (https://www.researchgate.net/publication/257017356_On_computing_the_distribution_function_for_the_Poisson_binomial_distribution)
/// </para>
/// <para>
/// This has been verified against a reference implementation provided by the authors in the R package "poibin",
/// which can be viewed here:
/// https://rdrr.io/cran/poibin/man/poibin-package.html
/// </para>
/// </remarks>
public class PoissonBinomial
{
/// <summary>
/// The expected value of the distribution.
/// </summary>
private readonly double mu;
/// <summary>
/// The standard deviation of the distribution.
/// </summary>
private readonly double sigma;
/// <summary>
/// The gamma factor from equation (11) in the cited paper, pre-divided by 6 to save on re-computation.
/// </summary>
private readonly double v;
/// <summary>
/// Creates a Poisson binomial distribution based on N trials with the provided difficulties, skill, and method for getting the miss probabilities.
/// </summary>
/// <param name="difficulties">The list of difficulties in the map.</param>
/// <param name="skill">The skill level to get the miss probabilities with.</param>
/// <param name="hitProbability">Converts difficulties and skill to miss probabilities.</param>
public PoissonBinomial(IList<double> difficulties, double skill, Func<double, double, double> hitProbability)
{
double variance = 0;
double gamma = 0;
foreach (double d in difficulties)
{
double p = 1 - hitProbability(skill, d);
mu += p;
variance += p * (1 - p);
gamma += p * (1 - p) * (1 - 2 * p);
}
sigma = Math.Sqrt(variance);
v = gamma / (6 * Math.Pow(sigma, 3));
}
/// <summary>
/// Creates a Poisson binomial distribution based on N trials with the provided bins of difficulties, skill, and method for getting the miss probabilities.
/// </summary>
/// <param name="bins">The bins of difficulties in the map.</param>
/// <param name="skill">The skill level to get the miss probabilities with.</param>
/// /// <param name="hitProbability">Converts difficulties and skill to miss probabilities.</param>
public PoissonBinomial(Bin[] bins, double skill, Func<double, double, double> hitProbability)
{
double variance = 0;
double gamma = 0;
foreach (Bin bin in bins)
{
double p = 1 - hitProbability(skill, bin.Difficulty);
mu += p * bin.Count;
variance += p * (1 - p) * bin.Count;
gamma += p * (1 - p) * (1 - 2 * p) * bin.Count;
}
sigma = Math.Sqrt(variance);
v = gamma / (6 * Math.Pow(sigma, 3));
}
/// <summary>
/// Computes the value of the cumulative distribution function for this Poisson binomial distribution.
/// </summary>
/// <param name="count">
/// The argument of the CDF to sample the distribution for.
/// In the discrete case (when it is a whole number), this corresponds to the number
/// of successful Bernoulli trials to query the CDF for.
/// </param>
/// <returns>
/// The value of the CDF at <paramref name="count"/>.
/// In the discrete case this corresponds to the probability that at most <paramref name="count"/>
/// Bernoulli trials ended in a success.
/// </returns>
// ReSharper disable once InconsistentNaming
public double CDF(double count)
{
if (sigma == 0)
return 1;
double k = (count + 0.5 - mu) / sigma;
// see equation (14) of the cited paper
double result = SpecialFunctions.NormalCdf(0, 1, k) + v * (1 - k * k) * SpecialFunctions.NormalPdf(0, 1, k);
return Math.Clamp(result, 0, 1);
}
}
}

View File

@ -18,6 +18,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils
{
public class SpecialFunctions
{
private const double sqrt2 = 1.4142135623730950488016887242096980785696718753769d;
private const double sqrt2_pi = 2.5066282746310005024157652848110452530069867406099d;
/// <summary>
@ -690,5 +691,42 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Utils
return sum;
}
/// <summary>
/// Computes the probability density of the distribution (PDF) at x, i.e. ∂P(X ≤ x)/∂x.
/// </summary>
/// <param name="mean">The mean (μ) of the normal distribution.</param>
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ ≥ 0.</param>
/// <param name="x">The location at which to compute the density.</param>
/// <returns>the density at <paramref name="x"/>.</returns>
/// <remarks>MATLAB: normpdf</remarks>
public static double NormalPdf(double mean, double stddev, double x)
{
if (stddev < 0.0)
{
throw new ArgumentException("Invalid parametrization for the distribution.");
}
double d = (x - mean) / stddev;
return Math.Exp(-0.5 * d * d) / (sqrt2_pi * stddev);
}
/// <summary>
/// Computes the cumulative distribution (CDF) of the distribution at x, i.e. P(X ≤ x).
/// </summary>
/// <param name="x">The location at which to compute the cumulative distribution function.</param>
/// <param name="mean">The mean (μ) of the normal distribution.</param>
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ ≥ 0.</param>
/// <returns>the cumulative distribution at location <paramref name="x"/>.</returns>
/// <remarks>MATLAB: normcdf</remarks>
public static double NormalCdf(double mean, double stddev, double x)
{
if (stddev < 0.0)
{
throw new ArgumentException("Invalid parametrization for the distribution.");
}
return 0.5 * Erfc((mean - x) / (stddev * sqrt2));
}
}
}