| using System; |
| using System.Collections.Generic; |
| using Unity.MLAgents.Inference.Utils; |
| using Random = System.Random; |
|
|
namespace Unity.MLAgents
{
    /// <summary>
    /// Builds seeded sampling functions. Each factory method returns a closure that
    /// produces a new float every time it is invoked; the closure owns the random
    /// generator(s) it was created with, seeded from the <c>seed</c> argument.
    /// The closures built on <see cref="System.Random"/> are not thread-safe, so a
    /// single sampler should not be invoked concurrently without external locking.
    /// </summary>
    internal static class SamplerFactory
    {
        /// <summary>
        /// Creates a sampler that draws uniformly between <paramref name="min"/> and
        /// <paramref name="max"/>.
        /// </summary>
        /// <param name="min">Lower bound of the range.</param>
        /// <param name="max">Upper bound of the range.</param>
        /// <param name="seed">Seed for the underlying <see cref="System.Random"/>.</param>
        /// <returns>
        /// A function returning <c>min + u * (max - min)</c> for a fresh <c>u</c> in [0, 1).
        /// Because <c>u</c> is narrowed to float, rounding can occasionally yield exactly
        /// <paramref name="max"/>.
        /// </returns>
        public static Func<float> CreateUniformSampler(float min, float max, int seed)
        {
            Random distr = new Random(seed);
            return () => min + (float)distr.NextDouble() * (max - min);
        }

        /// <summary>
        /// Creates a sampler backed by a seeded <see cref="RandomNormal"/> generator.
        /// </summary>
        /// <param name="mean">Mean passed to <see cref="RandomNormal"/>.</param>
        /// <param name="stddev">Standard deviation passed to <see cref="RandomNormal"/>.</param>
        /// <param name="seed">Seed passed to <see cref="RandomNormal"/>.</param>
        /// <returns>A function returning the generator's next value, narrowed to float.</returns>
        public static Func<float> CreateGaussianSampler(float mean, float stddev, int seed)
        {
            RandomNormal distr = new RandomNormal(seed, mean, stddev);
            return () => (float)distr.NextDouble();
        }

        /// <summary>
        /// Creates a sampler over the union of several intervals. Each call first picks an
        /// interval with probability proportional to its width, then draws uniformly
        /// inside the chosen interval.
        /// </summary>
        /// <param name="intervals">
        /// Flat list of bounds laid out as <c>[min0, max0, min1, max1, ...]</c>. It must
        /// hold a positive, even number of values. A reversed pair (min &gt; max) is
        /// accepted and sampled between its two bounds. If every interval has zero width,
        /// the intervals are picked with equal probability.
        /// </param>
        /// <param name="seed">
        /// Seed for the generator that draws inside an interval; <c>seed + 1</c> seeds the
        /// generator that picks which interval to use.
        /// </param>
        /// <returns>A function returning a new sample per call.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="intervals"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="intervals"/> is empty or holds an odd number of values.
        /// </exception>
        public static Func<float> CreateMultiRangeUniformSampler(IList<float> intervals, int seed)
        {
            if (intervals == null)
            {
                throw new ArgumentNullException(nameof(intervals));
            }
            if (intervals.Count == 0 || intervals.Count % 2 != 0)
            {
                throw new ArgumentException(
                    $"Expected a non-empty list of [min, max] pairs but got {intervals.Count} value(s).",
                    nameof(intervals));
            }

            int numIntervals = intervals.Count / 2;
            float[] mins = new float[numIntervals];
            float[] sizes = new float[numIntervals];
            float[] cumulativeWeights = new float[numIntervals];

            float totalWeight = 0f;
            for (int i = 0; i < numIntervals; i++)
            {
                mins[i] = intervals[2 * i];
                sizes[i] = intervals[2 * i + 1] - mins[i];
                // Weight by absolute width so a reversed pair cannot contribute a negative
                // probability and break the monotonicity of the cumulative weights.
                totalWeight += Math.Abs(sizes[i]);
            }

            // Build the normalized cumulative mass function used to pick an interval.
            float running = 0f;
            for (int i = 0; i < numIntervals; i++)
            {
                // When every interval is a single point the widths sum to zero; use equal
                // weights rather than dividing by zero and filling the CMF with NaN.
                float weight = totalWeight > 0f
                    ? Math.Abs(sizes[i]) / totalWeight
                    : 1f / numIntervals;
                running += weight;
                cumulativeWeights[i] = running;
            }
            // Pin the last entry to exactly 1 so accumulated float rounding cannot leave
            // the CMF ending just below 1.
            cumulativeWeights[numIntervals - 1] = 1f;

            Random distr = new Random(seed);
            Multinomial intervalDistr = new Multinomial(seed + 1);
            float MultiRange()
            {
                // NOTE(review): assumes Multinomial.Sample takes a cumulative mass function
                // and returns an index in [0, numIntervals) — confirm in Inference.Utils.
                int sampledInterval = intervalDistr.Sample(cumulativeWeights);
                return mins[sampledInterval] + (float)distr.NextDouble() * sizes[sampledInterval];
            }

            return MultiRange;
        }
    }
}
|
|