mirror of
https://github.com/ppy/osu.git
synced 2024-12-05 03:03:21 +08:00
Merge 153ff37715
into f09d8f097a
This commit is contained in:
commit
4f68a6e7e7
@ -0,0 +1,155 @@
|
||||
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
|
||||
// See the LICENCE file in the repository root for full licence text.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using osu.Game.Rulesets.Difficulty.Preprocessing;
|
||||
using osu.Game.Rulesets.Difficulty.Skills;
|
||||
using osu.Game.Rulesets.Mods;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Utils;
|
||||
|
||||
namespace osu.Game.Rulesets.Osu.Difficulty.Aggregation
|
||||
{
|
||||
public abstract class OsuProbabilitySkill : Skill
|
||||
{
|
||||
protected OsuProbabilitySkill(Mod[] mods)
|
||||
: base(mods)
|
||||
{
|
||||
}
|
||||
|
||||
// We assume players have a 2% chance to hit every note in the map.
|
||||
// A higher value of fc_probability increases the influence of difficulty spikes,
|
||||
// while a lower value increases the influence of length and consistent difficulty.
|
||||
private const double fc_probability = 0.02;
|
||||
|
||||
private const int bin_count = 32;
|
||||
|
||||
// The number of difficulties there must be before we can be sure that binning difficulties would not change the output significantly.
|
||||
private double binThreshold => 2 * bin_count;
|
||||
|
||||
private readonly List<double> difficulties = new List<double>();
|
||||
|
||||
/// <summary>
|
||||
/// Returns the strain value at <see cref="DifficultyHitObject"/>. This value is calculated with or without respect to previous objects.
|
||||
/// </summary>
|
||||
protected abstract double StrainValueAt(DifficultyHitObject current);
|
||||
|
||||
public override void Process(DifficultyHitObject current)
|
||||
{
|
||||
difficulties.Add(StrainValueAt(current));
|
||||
}
|
||||
|
||||
protected abstract double HitProbability(double skill, double difficulty);
|
||||
|
||||
private double difficultyValueExact()
|
||||
{
|
||||
double maxDiff = difficulties.Max();
|
||||
if (maxDiff <= 1e-10) return 0;
|
||||
|
||||
const double lower_bound = 0;
|
||||
double upperBoundEstimate = 3.0 * maxDiff;
|
||||
|
||||
double skill = RootFinding.FindRootExpand(
|
||||
skill => fcProbability(skill) - fc_probability,
|
||||
lower_bound,
|
||||
upperBoundEstimate,
|
||||
accuracy: 1e-4);
|
||||
|
||||
return skill;
|
||||
|
||||
double fcProbability(double s)
|
||||
{
|
||||
if (s <= 0) return 0;
|
||||
|
||||
return difficulties.Aggregate<double, double>(1, (current, d) => current * HitProbability(s, d));
|
||||
}
|
||||
}
|
||||
|
||||
private double difficultyValueBinned()
|
||||
{
|
||||
double maxDiff = difficulties.Max();
|
||||
if (maxDiff <= 1e-10) return 0;
|
||||
|
||||
var bins = Bin.CreateBins(difficulties, bin_count);
|
||||
|
||||
const double lower_bound = 0;
|
||||
double upperBoundEstimate = 3.0 * maxDiff;
|
||||
|
||||
double skill = RootFinding.FindRootExpand(
|
||||
skill => fcProbability(skill) - fc_probability,
|
||||
lower_bound,
|
||||
upperBoundEstimate,
|
||||
accuracy: 1e-4);
|
||||
|
||||
return skill;
|
||||
|
||||
double fcProbability(double s)
|
||||
{
|
||||
if (s <= 0) return 0;
|
||||
|
||||
return bins.Aggregate(1.0, (current, bin) => current * Math.Pow(HitProbability(s, bin.Difficulty), bin.Count));
|
||||
}
|
||||
}
|
||||
|
||||
public override double DifficultyValue()
|
||||
{
|
||||
if (difficulties.Count == 0) return 0;
|
||||
|
||||
return difficulties.Count > binThreshold ? difficultyValueBinned() : difficultyValueExact();
|
||||
}
|
||||
|
||||
/// <returns>
|
||||
/// A polynomial fitted to the miss counts at each skill level.
|
||||
/// </returns>
|
||||
public ExpPolynomial GetMissPenaltyCurve()
|
||||
{
|
||||
double[] missCounts = new double[7];
|
||||
double[] penalties = { 1, 0.95, 0.9, 0.8, 0.6, 0.3, 0 };
|
||||
|
||||
ExpPolynomial missPenaltyCurve = new ExpPolynomial();
|
||||
|
||||
// If there are no notes, we just return the curve with all coefficients set to zero.
|
||||
if (difficulties.Count == 0 || difficulties.Max() == 0)
|
||||
return missPenaltyCurve;
|
||||
|
||||
double fcSkill = DifficultyValue();
|
||||
|
||||
var bins = Bin.CreateBins(difficulties, bin_count);
|
||||
|
||||
for (int i = 0; i < penalties.Length; i++)
|
||||
{
|
||||
if (i == 0)
|
||||
{
|
||||
missCounts[i] = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
double penalizedSkill = fcSkill * penalties[i];
|
||||
|
||||
missCounts[i] = getMissCountAtSkill(penalizedSkill, bins);
|
||||
}
|
||||
|
||||
missPenaltyCurve.Fit(missCounts);
|
||||
|
||||
return missPenaltyCurve;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Find the lowest miss count that a player with the provided <paramref name="skill"/> would have a 2% chance of achieving or better.
|
||||
/// </summary>
|
||||
private double getMissCountAtSkill(double skill, List<Bin> bins)
|
||||
{
|
||||
double maxDiff = difficulties.Max();
|
||||
|
||||
if (maxDiff == 0)
|
||||
return 0;
|
||||
if (skill <= 0)
|
||||
return difficulties.Count;
|
||||
|
||||
var poiBin = difficulties.Count > binThreshold ? new PoissonBinomial(bins, skill, HitProbability) : new PoissonBinomial(difficulties, skill, HitProbability);
|
||||
|
||||
return Math.Max(0, RootFinding.FindRootExpand(x => poiBin.CDF(x) - fc_probability, -50, 1000, accuracy: 1e-4));
|
||||
}
|
||||
}
|
||||
}
|
@ -8,7 +8,7 @@ using osu.Game.Rulesets.Mods;
|
||||
using System.Linq;
|
||||
using osu.Framework.Utils;
|
||||
|
||||
namespace osu.Game.Rulesets.Osu.Difficulty.Skills
|
||||
namespace osu.Game.Rulesets.Osu.Difficulty.Aggregation
|
||||
{
|
||||
public abstract class OsuStrainSkill : StrainSkill
|
||||
{
|
@ -8,6 +8,7 @@ using Newtonsoft.Json;
|
||||
using osu.Game.Beatmaps;
|
||||
using osu.Game.Rulesets.Difficulty;
|
||||
using osu.Game.Rulesets.Mods;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Utils;
|
||||
|
||||
namespace osu.Game.Rulesets.Osu.Difficulty
|
||||
{
|
||||
@ -19,6 +20,12 @@ namespace osu.Game.Rulesets.Osu.Difficulty
|
||||
[JsonProperty("aim_difficulty")]
|
||||
public double AimDifficulty { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// The difficulty corresponding to the aim skill.
|
||||
/// </summary>
|
||||
[JsonProperty("aim_penalty_constants")]
|
||||
public ExpPolynomial AimMissPenaltyCurve { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// The difficulty corresponding to the speed skill.
|
||||
/// </summary>
|
||||
|
@ -11,8 +11,10 @@ using osu.Game.Rulesets.Difficulty;
|
||||
using osu.Game.Rulesets.Difficulty.Preprocessing;
|
||||
using osu.Game.Rulesets.Difficulty.Skills;
|
||||
using osu.Game.Rulesets.Mods;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Aggregation;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Preprocessing;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Skills;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Utils;
|
||||
using osu.Game.Rulesets.Osu.Mods;
|
||||
using osu.Game.Rulesets.Osu.Objects;
|
||||
using osu.Game.Rulesets.Osu.Scoring;
|
||||
@ -48,7 +50,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty
|
||||
|
||||
double sliderFactor = aimRating > 0 ? aimRatingNoSliders / aimRating : 1;
|
||||
|
||||
double aimDifficultyStrainCount = ((OsuStrainSkill)skills[0]).CountTopWeightedStrains();
|
||||
ExpPolynomial aimMissPenaltyCurve = ((OsuProbabilitySkill)skills[0]).GetMissPenaltyCurve();
|
||||
double speedDifficultyStrainCount = ((OsuStrainSkill)skills[2]).CountTopWeightedStrains();
|
||||
|
||||
if (mods.Any(m => m is OsuModTouchDevice))
|
||||
@ -103,7 +105,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty
|
||||
SpeedNoteCount = speedNotes,
|
||||
FlashlightDifficulty = flashlightRating,
|
||||
SliderFactor = sliderFactor,
|
||||
AimDifficultStrainCount = aimDifficultyStrainCount,
|
||||
AimMissPenaltyCurve = aimMissPenaltyCurve,
|
||||
SpeedDifficultStrainCount = speedDifficultyStrainCount,
|
||||
ApproachRate = preempt > 1200 ? (1800 - preempt) / 120 : (1200 - preempt) / 150 + 5,
|
||||
OverallDifficulty = (80 - hitWindowGreat) / 6,
|
||||
|
@ -5,7 +5,9 @@ using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using osu.Game.Rulesets.Difficulty;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Aggregation;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Skills;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Utils;
|
||||
using osu.Game.Rulesets.Osu.Mods;
|
||||
using osu.Game.Rulesets.Scoring;
|
||||
using osu.Game.Scoring;
|
||||
@ -139,10 +141,9 @@ namespace osu.Game.Rulesets.Osu.Difficulty
|
||||
|
||||
double lengthBonus = 0.95 + 0.4 * Math.Min(1.0, totalHits / 2000.0) +
|
||||
(totalHits > 2000 ? Math.Log10(totalHits / 2000.0) * 0.5 : 0.0);
|
||||
aimValue *= lengthBonus;
|
||||
|
||||
if (effectiveMissCount > 0)
|
||||
aimValue *= calculateMissPenalty(effectiveMissCount, attributes.AimDifficultStrainCount);
|
||||
aimValue *= calculateCurveFittedMissPenalty(effectiveMissCount, attributes.AimMissPenaltyCurve);
|
||||
|
||||
double approachRateFactor = 0.0;
|
||||
if (attributes.ApproachRate > 10.33)
|
||||
@ -206,7 +207,7 @@ namespace osu.Game.Rulesets.Osu.Difficulty
|
||||
speedValue *= lengthBonus;
|
||||
|
||||
if (effectiveMissCount > 0)
|
||||
speedValue *= calculateMissPenalty(effectiveMissCount, attributes.SpeedDifficultStrainCount);
|
||||
speedValue *= calculateStrainCountMissPenalty(effectiveMissCount, attributes.SpeedDifficultStrainCount);
|
||||
|
||||
double approachRateFactor = 0.0;
|
||||
if (attributes.ApproachRate > 10.33)
|
||||
@ -305,10 +306,17 @@ namespace osu.Game.Rulesets.Osu.Difficulty
|
||||
return flashlightValue;
|
||||
}
|
||||
|
||||
// Miss penalty assumes that a player will miss on the hardest parts of a map,
|
||||
// so we use the amount of relatively difficult sections to adjust miss penalty
|
||||
// to make it more punishing on maps with lower amount of hard sections.
|
||||
private double calculateMissPenalty(double missCount, double difficultStrainCount) => 0.96 / ((missCount / (4 * Math.Pow(Math.Log(difficultStrainCount), 0.94))) + 1);
|
||||
// Due to the unavailability of miss location in PP, the following formulas assume that a player will miss on the hardest parts of a map.
|
||||
|
||||
// With the curve fitted miss penalty, we use a pre-computed curve of skill levels for each miss count, raised to the power of 1.5 as
|
||||
// the multiple of the exponents on star rating and PP. This power should be changed if either SR or PP begin to use a different exponent.
|
||||
// As a result, this exponent is not subject to balance.
|
||||
private double calculateCurveFittedMissPenalty(double missCount, ExpPolynomial curve) => Math.Pow(1 - curve.GetPenaltyAt(missCount), 1.5);
|
||||
|
||||
// With the strain count miss penalty, we use the amount of relatively difficult sections to adjust the miss penalty,
|
||||
// to make it more punishing on maps with lower amount of hard sections. This formula is subject to balance.
|
||||
private double calculateStrainCountMissPenalty(double missCount, double difficultStrainCount) => 0.96 / (missCount / (4 * Math.Pow(Math.Log(difficultStrainCount), 0.94)) + 1);
|
||||
|
||||
private double getComboScalingFactor(OsuDifficultyAttributes attributes) => attributes.MaxCombo <= 0 ? 1.0 : Math.Min(Math.Pow(scoreMaxCombo, 0.8) / Math.Pow(attributes.MaxCombo, 0.8), 1.0);
|
||||
private int totalHits => countGreat + countOk + countMeh + countMiss;
|
||||
private int totalImperfectHits => countOk + countMeh + countMiss;
|
||||
|
@ -4,14 +4,16 @@
|
||||
using System;
|
||||
using osu.Game.Rulesets.Difficulty.Preprocessing;
|
||||
using osu.Game.Rulesets.Mods;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Aggregation;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Evaluators;
|
||||
using osu.Game.Utils;
|
||||
|
||||
namespace osu.Game.Rulesets.Osu.Difficulty.Skills
|
||||
{
|
||||
/// <summary>
|
||||
/// Represents the skill required to correctly aim at every object in the map with a uniform CircleSize and normalized distances.
|
||||
/// </summary>
|
||||
public class Aim : OsuStrainSkill
|
||||
public class Aim : OsuProbabilitySkill
|
||||
{
|
||||
public Aim(Mod[] mods, bool withSliders)
|
||||
: base(mods)
|
||||
@ -23,12 +25,18 @@ namespace osu.Game.Rulesets.Osu.Difficulty.Skills
|
||||
|
||||
private double currentStrain;
|
||||
|
||||
private double skillMultiplier => 25.18;
|
||||
private double skillMultiplier => 132;
|
||||
private double strainDecayBase => 0.15;
|
||||
|
||||
private double strainDecay(double ms) => Math.Pow(strainDecayBase, ms / 1000);
|
||||
protected override double HitProbability(double skill, double difficulty)
|
||||
{
|
||||
if (difficulty <= 0) return 1;
|
||||
if (skill <= 0) return 0;
|
||||
|
||||
protected override double CalculateInitialStrain(double time, DifficultyHitObject current) => currentStrain * strainDecay(time - current.Previous(0).StartTime);
|
||||
return SpecialFunctions.Erf(skill / (Math.Sqrt(2) * difficulty));
|
||||
}
|
||||
|
||||
private double strainDecay(double ms) => Math.Pow(strainDecayBase, ms / 1000);
|
||||
|
||||
protected override double StrainValueAt(DifficultyHitObject current)
|
||||
{
|
||||
|
@ -7,6 +7,7 @@ using osu.Game.Rulesets.Mods;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Evaluators;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Preprocessing;
|
||||
using System.Linq;
|
||||
using osu.Game.Rulesets.Osu.Difficulty.Aggregation;
|
||||
|
||||
namespace osu.Game.Rulesets.Osu.Difficulty.Skills
|
||||
{
|
||||
|
62
osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs
Normal file
62
osu.Game.Rulesets.Osu/Difficulty/Utils/Bin.cs
Normal file
@ -0,0 +1,62 @@
|
||||
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
|
||||
// See the LICENCE file in the repository root for full licence text.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
|
||||
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
|
||||
{
|
||||
public struct Bin
|
||||
{
|
||||
public double Difficulty;
|
||||
public double Count;
|
||||
|
||||
/// <summary>
|
||||
/// Create an array of spaced bins. Count is linearly interpolated into each bin.
|
||||
/// For example, if we have bins with values [1,2,3,4,5] and want to insert the value 3.2,
|
||||
/// we will add 0.8 to the count of 3's and 0.2 to the count of 4's
|
||||
/// </summary>
|
||||
public static List<Bin> CreateBins(List<double> difficulties, int totalBins)
|
||||
{
|
||||
double maxDifficulty = difficulties.Max();
|
||||
|
||||
var binsArray = new Bin[totalBins];
|
||||
|
||||
for (int i = 0; i < totalBins; i++)
|
||||
{
|
||||
binsArray[i].Difficulty = maxDifficulty * (i + 1) / totalBins;
|
||||
}
|
||||
|
||||
foreach (double d in difficulties)
|
||||
{
|
||||
double binIndex = totalBins * (d / maxDifficulty) - 1;
|
||||
|
||||
int lowerBound = (int)Math.Floor(binIndex);
|
||||
double t = binIndex - lowerBound;
|
||||
|
||||
//This can be -1, corresponding to the zero difficulty bucket.
|
||||
//We don't store that since it doesn't contribute to difficulty
|
||||
if (lowerBound >= 0)
|
||||
{
|
||||
binsArray[lowerBound].Count += (1 - t);
|
||||
}
|
||||
|
||||
int upperBound = lowerBound + 1;
|
||||
|
||||
// this can be == bin_count for the maximum difficulty object, in which case t will be 0 anyway
|
||||
if (upperBound < totalBins)
|
||||
{
|
||||
binsArray[upperBound].Count += t;
|
||||
}
|
||||
}
|
||||
|
||||
var binsList = binsArray.ToList();
|
||||
|
||||
// For a slight performance improvement, we remove bins that don't contribute to difficulty.
|
||||
binsList.RemoveAll(bin => bin.Count == 0);
|
||||
|
||||
return binsList;
|
||||
}
|
||||
}
|
||||
}
|
86
osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs
Normal file
86
osu.Game.Rulesets.Osu/Difficulty/Utils/ExpPolynomial.cs
Normal file
@ -0,0 +1,86 @@
|
||||
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
|
||||
// See the LICENCE file in the repository root for full licence text.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using osu.Game.Utils;
|
||||
|
||||
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
|
||||
{
|
||||
/// <summary>
|
||||
/// Represents a polynomial fitted to the logarithm of a given set of points.
|
||||
/// The resulting polynomial is exponentiated, ensuring low residuals for
|
||||
/// small inputs while handling exponentially increasing trends in the data.
|
||||
/// This approach is useful for modelling the results of decreasing skill with few coefficients,
|
||||
/// as linear decreases in skill correspond with exponential increases in miss counts.
|
||||
/// </summary>
|
||||
public struct ExpPolynomial
|
||||
{
|
||||
private double[]? coefficients;
|
||||
|
||||
// The matrix that minimizes the square error at X values [0.0, 0.30, 0.60, 0.80, 0.90, 0.95, 1.0].
|
||||
private static readonly double[][] matrix =
|
||||
{
|
||||
new[] { 0.0, 3.14395, 5.18439, 6.46975, 1.4638, -9.53526, 0.0 },
|
||||
new[] { 0.0, -4.85829, -8.09612, -10.4498, -3.84479, 12.1626, 0.0 }
|
||||
};
|
||||
|
||||
/// <summary>
|
||||
/// Computes the coefficients of a quartic polynomial, starting at 0 and ending at the highest miss count in the array.
|
||||
/// </summary>
|
||||
/// <param name="missCounts">A list of miss counts, with X values [1, 0.95, 0.9, 0.8, 0.6, 0.3, 0] corresponding to their skill levels.</param>
|
||||
public void Fit(double[] missCounts)
|
||||
{
|
||||
List<double> logMissCounts = missCounts.Select(x => Math.Log(x + 1)).ToList();
|
||||
|
||||
double endPoint = logMissCounts.Max();
|
||||
|
||||
double[] penalties = { 1, 0.95, 0.9, 0.8, 0.6, 0.3, 0 };
|
||||
|
||||
for (int i = 0; i < logMissCounts.Count; i++)
|
||||
{
|
||||
logMissCounts[i] -= endPoint * (1 - penalties[i]);
|
||||
}
|
||||
|
||||
// The precomputed matrix assumes the miss counts go in order of greatest to least.
|
||||
logMissCounts.Reverse();
|
||||
|
||||
coefficients = new double[4];
|
||||
|
||||
coefficients[3] = endPoint;
|
||||
|
||||
// Now we dot product the adjusted miss counts with the precomputed matrix.
|
||||
for (int row = 0; row < matrix.Length; row++)
|
||||
{
|
||||
for (int column = 0; column < matrix[row].Length; column++)
|
||||
{
|
||||
coefficients[row] += matrix[row][column] * logMissCounts[column];
|
||||
}
|
||||
|
||||
coefficients[2] -= coefficients[row];
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Solve for the miss penalty at a specified miss count.
|
||||
/// </summary>
|
||||
public double GetPenaltyAt(double missCount)
|
||||
{
|
||||
if (coefficients is null)
|
||||
return 1;
|
||||
|
||||
List<double> listCoefficients = coefficients.ToList();
|
||||
listCoefficients.Add(-Math.Log(missCount + 1));
|
||||
|
||||
List<double?> xVals = SpecialFunctions.SolvePolynomialRoots(listCoefficients);
|
||||
|
||||
const double max_error = 1e-7;
|
||||
|
||||
// We find the largest value of x (corresponding to the penalty) found as a root of the function, with a fallback of a 100% penalty if no roots were found.
|
||||
double largestValue = xVals.Where(x => x >= 0 - max_error && x <= 1 + max_error).OrderDescending().FirstOrDefault() ?? 1;
|
||||
|
||||
return Math.Clamp(largestValue, 0, 1);
|
||||
}
|
||||
}
|
||||
}
|
120
osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs
Normal file
120
osu.Game.Rulesets.Osu/Difficulty/Utils/PoissonBinomial.cs
Normal file
@ -0,0 +1,120 @@
|
||||
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
|
||||
// See the LICENCE file in the repository root for full licence text.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using osu.Game.Utils;
|
||||
|
||||
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
|
||||
{
|
||||
/// <summary>
|
||||
/// Approximation of the Poisson binomial distribution:
|
||||
/// https://en.wikipedia.org/wiki/Poisson_binomial_distribution
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// <para>
|
||||
/// For the approximation method, see "Refined Normal Approximation (RNA)" from:
|
||||
/// Hong, Y. (2013). On computing the distribution function for the Poisson binomial distribution. Computational Statistics and Data Analysis, Vol. 59, pp. 41-51.
|
||||
/// (https://www.researchgate.net/publication/257017356_On_computing_the_distribution_function_for_the_Poisson_binomial_distribution)
|
||||
/// </para>
|
||||
/// <para>
|
||||
/// This has been verified against a reference implementation provided by the authors in the R package "poibin",
|
||||
/// which can be viewed here:
|
||||
/// https://rdrr.io/cran/poibin/man/poibin-package.html
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
public class PoissonBinomial
|
||||
{
|
||||
/// <summary>
|
||||
/// The expected value of the distribution.
|
||||
/// </summary>
|
||||
private readonly double mu;
|
||||
|
||||
/// <summary>
|
||||
/// The standard deviation of the distribution.
|
||||
/// </summary>
|
||||
private readonly double sigma;
|
||||
|
||||
/// <summary>
|
||||
/// The gamma factor from equation (11) in the cited paper, pre-divided by 6 to save on re-computation.
|
||||
/// </summary>
|
||||
private readonly double v;
|
||||
|
||||
/// <summary>
|
||||
/// Creates a Poisson binomial distribution based on N trials with the provided difficulties, skill, and method for getting the miss probabilities.
|
||||
/// </summary>
|
||||
/// <param name="difficulties">The list of difficulties in the map.</param>
|
||||
/// <param name="skill">The skill level to get the miss probabilities with.</param>
|
||||
/// <param name="hitProbability">Converts difficulties and skill to miss probabilities.</param>
|
||||
public PoissonBinomial(IList<double> difficulties, double skill, Func<double, double, double> hitProbability)
|
||||
{
|
||||
double variance = 0;
|
||||
double gamma = 0;
|
||||
|
||||
foreach (double d in difficulties)
|
||||
{
|
||||
double p = 1 - hitProbability(skill, d);
|
||||
|
||||
mu += p;
|
||||
variance += p * (1 - p);
|
||||
gamma += p * (1 - p) * (1 - 2 * p);
|
||||
}
|
||||
|
||||
sigma = Math.Sqrt(variance);
|
||||
|
||||
v = gamma / (6 * Math.Pow(sigma, 3));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Creates a Poisson binomial distribution based on N trials with the provided bins of difficulties, skill, and method for getting the miss probabilities.
|
||||
/// </summary>
|
||||
/// <param name="bins">The bins of difficulties in the map.</param>
|
||||
/// <param name="skill">The skill level to get the miss probabilities with.</param>
|
||||
/// /// <param name="hitProbability">Converts difficulties and skill to miss probabilities.</param>
|
||||
public PoissonBinomial(List<Bin> bins, double skill, Func<double, double, double> hitProbability)
|
||||
{
|
||||
double variance = 0;
|
||||
double gamma = 0;
|
||||
|
||||
foreach (Bin bin in bins)
|
||||
{
|
||||
double p = 1 - hitProbability(skill, bin.Difficulty);
|
||||
|
||||
mu += p * bin.Count;
|
||||
variance += p * (1 - p) * bin.Count;
|
||||
gamma += p * (1 - p) * (1 - 2 * p) * bin.Count;
|
||||
}
|
||||
|
||||
sigma = Math.Sqrt(variance);
|
||||
|
||||
v = gamma / (6 * Math.Pow(sigma, 3));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Computes the value of the cumulative distribution function for this Poisson binomial distribution.
|
||||
/// </summary>
|
||||
/// <param name="count">
|
||||
/// The argument of the CDF to sample the distribution for.
|
||||
/// In the discrete case (when it is a whole number), this corresponds to the number
|
||||
/// of successful Bernoulli trials to query the CDF for.
|
||||
/// </param>
|
||||
/// <returns>
|
||||
/// The value of the CDF at <paramref name="count"/>.
|
||||
/// In the discrete case this corresponds to the probability that at most <paramref name="count"/>
|
||||
/// Bernoulli trials ended in a success.
|
||||
/// </returns>
|
||||
// ReSharper disable once InconsistentNaming
|
||||
public double CDF(double count)
|
||||
{
|
||||
if (sigma == 0)
|
||||
return 1;
|
||||
|
||||
double k = (count + 0.5 - mu) / sigma;
|
||||
|
||||
// see equation (14) of the cited paper
|
||||
double result = SpecialFunctions.NormalCdf(0, 1, k) + v * (1 - k * k) * SpecialFunctions.NormalPdf(0, 1, k);
|
||||
|
||||
return Math.Clamp(result, 0, 1);
|
||||
}
|
||||
}
|
||||
}
|
119
osu.Game.Rulesets.Osu/Difficulty/Utils/RootFinding.cs
Normal file
119
osu.Game.Rulesets.Osu/Difficulty/Utils/RootFinding.cs
Normal file
@ -0,0 +1,119 @@
|
||||
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
|
||||
// See the LICENCE file in the repository root for full licence text.
|
||||
|
||||
using System;
|
||||
|
||||
namespace osu.Game.Rulesets.Osu.Difficulty.Utils
|
||||
{
|
||||
public static class RootFinding
|
||||
{
|
||||
/// <summary>
|
||||
/// Finds the root of a <paramref name="function"/> using the Chandrupatla method, expanding the bounds if the root is not located within.
|
||||
/// Expansion only occurs for the upward bound, as this function is optimized for functions of range [0, x),
|
||||
/// which is useful for finding skill level (skill can never be below 0).
|
||||
/// </summary>
|
||||
/// <param name="function">The function of which to find the root.</param>
|
||||
/// <param name="guessLowerBound">The lower bound of the function inputs.</param>
|
||||
/// <param name="guessUpperBound">The upper bound of the function inputs.</param>
|
||||
/// <param name="maxIterations">The maximum number of iterations before the function throws an error.</param>
|
||||
/// <param name="accuracy">The desired precision in which the root is returned.</param>
|
||||
/// <param name="expansionFactor">The multiplier on the upper bound when no root is found within the provided bounds.</param>
|
||||
/// <param name="maxExpansions">The maximum number of times the bounds of the function should increase.</param>
|
||||
public static double FindRootExpand(Func<double, double> function, double guessLowerBound, double guessUpperBound, int maxIterations = 25, double accuracy = 1e-6D, double expansionFactor = 2, double maxExpansions = 32)
|
||||
{
|
||||
double a = guessLowerBound;
|
||||
double b = guessUpperBound;
|
||||
double fa = function(a);
|
||||
double fb = function(b);
|
||||
|
||||
int expansions = 0;
|
||||
|
||||
while (fa * fb > 0)
|
||||
{
|
||||
a = b;
|
||||
b *= expansionFactor;
|
||||
fa = function(a);
|
||||
fb = function(b);
|
||||
|
||||
expansions++;
|
||||
|
||||
if (expansions > maxExpansions)
|
||||
{
|
||||
throw new MaximumIterationsException("No root was found within the provided function.");
|
||||
}
|
||||
}
|
||||
|
||||
double t = 0.5;
|
||||
|
||||
for (int i = 0; i < maxIterations; i++)
|
||||
{
|
||||
double xt = a + t * (b - a);
|
||||
double ft = function(xt);
|
||||
|
||||
double c;
|
||||
double fc;
|
||||
|
||||
if (Math.Sign(ft) == Math.Sign(fa))
|
||||
{
|
||||
c = a;
|
||||
fc = fa;
|
||||
}
|
||||
else
|
||||
{
|
||||
c = b;
|
||||
b = a;
|
||||
fc = fb;
|
||||
fb = fa;
|
||||
}
|
||||
|
||||
a = xt;
|
||||
fa = ft;
|
||||
|
||||
double xm, fm;
|
||||
|
||||
if (Math.Abs(fa) < Math.Abs(fb))
|
||||
{
|
||||
xm = a;
|
||||
fm = fa;
|
||||
}
|
||||
else
|
||||
{
|
||||
xm = b;
|
||||
fm = fb;
|
||||
}
|
||||
|
||||
if (fm == 0)
|
||||
return xm;
|
||||
|
||||
double tol = 2 * accuracy * Math.Abs(xm) + 2 * accuracy;
|
||||
double tlim = tol / Math.Abs(b - c);
|
||||
|
||||
if (tlim > 0.5)
|
||||
{
|
||||
return xm;
|
||||
}
|
||||
|
||||
double chi = (a - b) / (c - b);
|
||||
double phi = (fa - fb) / (fc - fb);
|
||||
bool iqi = phi * phi < chi && (1 - phi) * (1 - phi) < chi;
|
||||
|
||||
if (iqi)
|
||||
t = fa / (fb - fa) * fc / (fb - fc) + (c - a) / (b - a) * fa / (fc - fa) * fb / (fc - fb);
|
||||
else
|
||||
t = 0.5;
|
||||
|
||||
t = Math.Min(1 - tlim, Math.Max(tlim, t));
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
private class MaximumIterationsException : Exception
|
||||
{
|
||||
public MaximumIterationsException(string message)
|
||||
: base(message)
|
||||
{
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -13,12 +13,17 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI
|
||||
*/
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using osu.Framework.Utils;
|
||||
|
||||
namespace osu.Game.Utils
|
||||
{
|
||||
public class SpecialFunctions
|
||||
{
|
||||
private const double sqrt2 = 1.4142135623730950488016887242096980785696718753769d;
|
||||
private const double sqrt2_pi = 2.5066282746310005024157652848110452530069867406099d;
|
||||
private const double pi_mult_2 = 6.28318530717958647692528676655900576d;
|
||||
|
||||
/// <summary>
|
||||
/// **************************************
|
||||
@ -690,5 +695,280 @@ namespace osu.Game.Utils
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Computes the probability density of the distribution (PDF) at x, i.e. ∂P(X ≤ x)/∂x.
|
||||
/// </summary>
|
||||
/// <param name="mean">The mean (μ) of the normal distribution.</param>
|
||||
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ ≥ 0.</param>
|
||||
/// <param name="x">The location at which to compute the density.</param>
|
||||
/// <returns>the density at <paramref name="x"/>.</returns>
|
||||
/// <remarks>MATLAB: normpdf</remarks>
|
||||
public static double NormalPdf(double mean, double stddev, double x)
|
||||
{
|
||||
if (stddev < 0.0)
|
||||
{
|
||||
throw new ArgumentException("Invalid parametrization for the distribution.");
|
||||
}
|
||||
|
||||
double d = (x - mean) / stddev;
|
||||
return Math.Exp(-0.5 * d * d) / (sqrt2_pi * stddev);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Computes the cumulative distribution (CDF) of the distribution at x, i.e. P(X ≤ x).
|
||||
/// </summary>
|
||||
/// <param name="x">The location at which to compute the cumulative distribution function.</param>
|
||||
/// <param name="mean">The mean (μ) of the normal distribution.</param>
|
||||
/// <param name="stddev">The standard deviation (σ) of the normal distribution. Range: σ ≥ 0.</param>
|
||||
/// <returns>the cumulative distribution at location <paramref name="x"/>.</returns>
|
||||
/// <remarks>MATLAB: normcdf</remarks>
|
||||
public static double NormalCdf(double mean, double stddev, double x)
|
||||
{
|
||||
if (stddev < 0.0)
|
||||
{
|
||||
throw new ArgumentException("Invalid parametrization for the distribution.");
|
||||
}
|
||||
|
||||
return 0.5 * Erfc((mean - x) / (stddev * sqrt2));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Solve for the exact real roots of any polynomial up to degree 4.
|
||||
/// </summary>
|
||||
/// <param name="coefficients">The coefficients of the polynomial, in ascending order ([1, 3, 5] -> x^2 + 3x + 5).</param>
|
||||
/// <returns>The real roots of the polynomial, and null if the root does not exist.</returns>
|
||||
public static List<double?> SolvePolynomialRoots(List<double> coefficients)
|
||||
{
|
||||
List<double?> xVals = new List<double?>();
|
||||
|
||||
switch (coefficients.Count)
|
||||
{
|
||||
case 5:
|
||||
xVals = solveP4(coefficients[0], coefficients[1], coefficients[2], coefficients[3], coefficients[4], out int _).ToList();
|
||||
break;
|
||||
|
||||
case 4:
|
||||
xVals = solveP3(coefficients[0], coefficients[1], coefficients[2], coefficients[3], out int _).ToList();
|
||||
break;
|
||||
|
||||
case 3:
|
||||
xVals = solveP2(coefficients[0], coefficients[1], coefficients[2], out int _).ToList();
|
||||
break;
|
||||
|
||||
case 2:
|
||||
xVals = solveP2(0, coefficients[1], coefficients[2], out int _).ToList();
|
||||
break;
|
||||
}
|
||||
|
||||
return xVals;
|
||||
}
|
||||
|
||||
// https://github.com/sasamil/Quartic/blob/master/quartic.cpp
|
||||
private static double?[] solveP4(double a, double b, double c, double d, double e, out int nRoots)
|
||||
{
|
||||
double?[] xVals = new double?[4];
|
||||
|
||||
nRoots = 0;
|
||||
|
||||
if (a == 0)
|
||||
{
|
||||
double?[] xValsCubic = solveP3(b, c, d, e, out nRoots);
|
||||
|
||||
xVals[0] = xValsCubic[0];
|
||||
xVals[1] = xValsCubic[1];
|
||||
xVals[2] = xValsCubic[2];
|
||||
xVals[3] = null;
|
||||
|
||||
return xVals;
|
||||
}
|
||||
|
||||
b /= a;
|
||||
c /= a;
|
||||
d /= a;
|
||||
e /= a;
|
||||
|
||||
double a3 = -c;
|
||||
double b3 = b * d - 4 * e;
|
||||
double c3 = -b * b * e - d * d + 4 * c * e;
|
||||
|
||||
double?[] x3 = solveP3(1, a3, b3, c3, out int iZeroes);
|
||||
|
||||
double q1, q2, p1, p2, sqD;
|
||||
|
||||
double y = x3[0]!.Value;
|
||||
|
||||
// Get the y value with the highest absolute value.
|
||||
if (iZeroes != 1)
|
||||
{
|
||||
if (Math.Abs(x3[1]!.Value) > Math.Abs(y))
|
||||
y = x3[1]!.Value;
|
||||
if (Math.Abs(x3[2]!.Value) > Math.Abs(y))
|
||||
y = x3[2]!.Value;
|
||||
}
|
||||
|
||||
double upperD = y * y - 4 * e;
|
||||
|
||||
if (Precision.AlmostEquals(upperD, 0))
|
||||
{
|
||||
q1 = q2 = y * 0.5;
|
||||
|
||||
upperD = b * b - 4 * (c - y);
|
||||
|
||||
if (Precision.AlmostEquals(upperD, 0))
|
||||
p1 = p2 = b * 0.5;
|
||||
|
||||
else
|
||||
{
|
||||
sqD = Math.Sqrt(upperD);
|
||||
p1 = (b + sqD) * 0.5;
|
||||
p2 = (b - sqD) * 0.5;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
sqD = Math.Sqrt(upperD);
|
||||
q1 = (y + sqD) * 0.5;
|
||||
q2 = (y - sqD) * 0.5;
|
||||
|
||||
p1 = (b * q1 - d) / (q1 - q2);
|
||||
p2 = (d - b * q2) / (q1 - q2);
|
||||
}
|
||||
|
||||
// solving quadratic eq. - x^2 + p1*x + q1 = 0
|
||||
upperD = p1 * p1 - 4 * q1;
|
||||
|
||||
if (upperD >= 0)
|
||||
{
|
||||
nRoots += 2;
|
||||
|
||||
sqD = Math.Sqrt(upperD);
|
||||
xVals[0] = (-p1 + sqD) * 0.5;
|
||||
xVals[1] = (-p1 - sqD) * 0.5;
|
||||
}
|
||||
|
||||
// solving quadratic eq. - x^2 + p2*x + q2 = 0
|
||||
upperD = p2 * p2 - 4 * q2;
|
||||
|
||||
if (upperD >= 0)
|
||||
{
|
||||
nRoots += 2;
|
||||
|
||||
sqD = Math.Sqrt(upperD);
|
||||
xVals[2] = (-p2 + sqD) * 0.5;
|
||||
xVals[3] = (-p2 - sqD) * 0.5;
|
||||
}
|
||||
|
||||
// Put the null roots at the end of the array.
|
||||
var nonNulls = xVals.Where(x => x != null);
|
||||
var nulls = xVals.Where(x => x == null);
|
||||
xVals = nonNulls.Concat(nulls).ToArray();
|
||||
|
||||
return xVals;
|
||||
}
|
||||
|
||||
private static double?[] solveP3(double a, double b, double c, double d, out int nRoots)
|
||||
{
|
||||
double?[] xVals = new double?[3];
|
||||
|
||||
nRoots = 0;
|
||||
|
||||
if (a == 0)
|
||||
{
|
||||
double?[] xValsQuadratic = solveP2(b, c, d, out nRoots);
|
||||
|
||||
xVals[0] = xValsQuadratic[0];
|
||||
xVals[1] = xValsQuadratic[1];
|
||||
xVals[2] = null;
|
||||
|
||||
return xVals;
|
||||
}
|
||||
|
||||
b /= a;
|
||||
c /= a;
|
||||
d /= a;
|
||||
|
||||
double a2 = b * b;
|
||||
double q = (a2 - 3 * c) / 9;
|
||||
double q3 = q * q * q;
|
||||
double r = (b * (2 * a2 - 9 * c) + 27 * d) / 54;
|
||||
double r2 = r * r;
|
||||
|
||||
if (r2 < q3)
|
||||
{
|
||||
nRoots = 3;
|
||||
|
||||
double t = r / Math.Sqrt(q3);
|
||||
t = Math.Clamp(t, -1, 1);
|
||||
t = Math.Acos(t);
|
||||
b /= 3;
|
||||
q = -2 * Math.Sqrt(q);
|
||||
|
||||
xVals[0] = q * Math.Cos(t / 3) - b;
|
||||
xVals[1] = q * Math.Cos((t + pi_mult_2) / 3) - b;
|
||||
xVals[2] = q * Math.Cos((t - pi_mult_2) / 3) - b;
|
||||
|
||||
return xVals;
|
||||
}
|
||||
|
||||
double upperA = -Math.Cbrt(Math.Abs(r) + Math.Sqrt(r2 - q3));
|
||||
|
||||
if (r < 0)
|
||||
upperA = -upperA;
|
||||
|
||||
double upperB = upperA == 0 ? 0 : q / upperA;
|
||||
b /= 3;
|
||||
|
||||
xVals[0] = upperA + upperB - b;
|
||||
|
||||
if (Precision.AlmostEquals(0.5 * Math.Sqrt(3) * (upperA - upperB), 0))
|
||||
{
|
||||
nRoots = 2;
|
||||
xVals[1] = -0.5 * (upperA + upperB) - b;
|
||||
|
||||
return xVals;
|
||||
}
|
||||
|
||||
nRoots = 1;
|
||||
|
||||
return xVals;
|
||||
}
|
||||
|
||||
private static double?[] solveP2(double a, double b, double c, out int nRoots)
|
||||
{
|
||||
double?[] xVals = new double?[2];
|
||||
|
||||
nRoots = 0;
|
||||
|
||||
if (a == 0)
|
||||
{
|
||||
if (b == 0)
|
||||
return xVals;
|
||||
|
||||
nRoots = 1;
|
||||
xVals[0] = -c / b;
|
||||
}
|
||||
|
||||
double discriminant = b * b - 4 * a * c;
|
||||
|
||||
switch (discriminant)
|
||||
{
|
||||
case < 0:
|
||||
break;
|
||||
|
||||
case 0:
|
||||
nRoots = 1;
|
||||
xVals[0] = -b / (2 * a);
|
||||
break;
|
||||
|
||||
default:
|
||||
nRoots = 2;
|
||||
xVals[0] = (-b + Math.Sqrt(discriminant)) / (2 * a);
|
||||
xVals[1] = (-b - Math.Sqrt(discriminant)) / (2 * a);
|
||||
break;
|
||||
}
|
||||
|
||||
return xVals;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user