public async Task Create_WithReplenishingLimiterReplenishesAutomatically()
        {
            // The token bucket is built through the general-purpose RateLimitPartition.Get overload (not the
            // token-bucket specific helper) to check that replenishing limiters are still recognised and refilled
            // by the partitioned limiter even though the bucket's own AutoReplenishment is switched off.
            using var limiter = PartitionedRateLimiter.Create<string, int>(resource =>
                RateLimitPartition.Get(1, _ => new TokenBucketRateLimiter(new TokenBucketRateLimiterOptions
                {
                    TokenLimit = 1,
                    QueueProcessingOrder = QueueProcessingOrder.NewestFirst,
                    QueueLimit = 1,
                    ReplenishmentPeriod = TimeSpan.FromMilliseconds(100),
                    TokensPerPeriod = 1,
                    AutoReplenishment = false
                })));

            // Consume the single token synchronously...
            var immediate = limiter.AttemptAcquire("");
            Assert.True(immediate.IsAcquired);

            // ...so this request has to wait for the bucket to be topped up again.
            var afterReplenish = await limiter.AcquireAsync("");
            Assert.True(afterReplenish.IsAcquired);
        }
        public async Task Create_DisposeAsyncDisposesAllLimiters()
        {
            // DisposeAsync on the partitioned limiter must reach every partition limiter it created:
            // each tracked limiter should record one acquire, one Dispose and one DisposeAsync call.
            var factory = new TrackingRateLimiterFactory<int>();

            using var limiter = PartitionedRateLimiter.Create<string, int>(resource =>
                RateLimitPartition.Get(resource == "1" ? 1 : 2, key => factory.GetLimiter(key)));

            limiter.Acquire("1");
            limiter.Acquire("2");

            await limiter.DisposeAsync();

            Assert.Equal(2, factory.Limiters.Count);
            for (var index = 0; index < 2; index++)
            {
                var tracked = factory.Limiters[index].Limiter;
                Assert.Equal(1, tracked.AcquireCallCount);
                Assert.Equal(1, tracked.DisposeCallCount);
                Assert.Equal(1, tracked.DisposeAsyncCallCount);
            }
        }
        public void Create_PassedInEqualityComparerIsUsed()
        {
            // The comparer handed to Create must be the one used for the partition-key dictionary;
            // its call counters are checked after every acquire.
            var factory = new TrackingRateLimiterFactory<int>();
            var comparer = new TestEquality();

            using var limiter = PartitionedRateLimiter.Create<string, int>(
                resource => RateLimitPartition.Get(resource == "1" ? 1 : 2, key => factory.GetLimiter(key)),
                comparer);

            // First key on an empty dictionary: only the insert hashes the key (the lookup is skipped).
            limiter.Acquire("1");
            Assert.Equal(0, comparer.EqualsCallCount);
            Assert.Equal(1, comparer.GetHashCodeCallCount);

            // Same key again: the lookup hashes it and compares it with the stored entry.
            limiter.Acquire("1");
            Assert.Equal(1, comparer.EqualsCallCount);
            Assert.Equal(2, comparer.GetHashCodeCallCount);

            // New key: one hash for the failed lookup and one more for the insert, no equality check.
            limiter.Acquire("2");
            Assert.Equal(1, comparer.EqualsCallCount);
            Assert.Equal(4, comparer.GetHashCodeCallCount);

            Assert.Equal(2, factory.Limiters.Count);
            Assert.Equal(2, factory.Limiters[0].Limiter.AcquireCallCount);
            Assert.Equal(1, factory.Limiters[1].Limiter.AcquireCallCount);
        }
        public async Task Create_BlockingWaitDoesNotBlockOtherPartitions()
        {
            // Partition "2" is a one-permit concurrency limiter with a queue; partition "1" is a tracked limiter.
            // A waiter queued on "2" must not delay requests on "1".
            var factory = new TrackingRateLimiterFactory<int>();

            using var limiter = PartitionedRateLimiter.Create<string, int>(resource =>
                resource == "1"
                    ? RateLimitPartition.Get(1, key => factory.GetLimiter(key))
                    : RateLimitPartition.GetConcurrencyLimiter(2,
                                                               _ => new ConcurrencyLimiterOptions(1, QueueProcessingOrder.OldestFirst, 2)));

            // Take the only permit of partition "2", then queue a second request behind it.
            var heldLease = await limiter.WaitAndAcquireAsync("2");
            var pendingWait = limiter.WaitAndAcquireAsync("2");
            Assert.False(pendingWait.IsCompleted);

            // Partition "1" is independent and must complete while "2" still has a queued waiter.
            await limiter.WaitAndAcquireAsync("1");

            // Releasing the permit lets the queued request on "2" finish.
            heldLease.Dispose();
            await pendingWait;

            Assert.Equal(1, factory.Limiters.Count);
            var tracked = factory.Limiters[0].Limiter;
            Assert.Equal(0, tracked.AcquireCallCount);
            Assert.Equal(1, tracked.WaitAndAcquireAsyncCallCount);
        }
Example #5
0
        internal static Func<Task> StopTimerAndGetTimerFunc<T>(PartitionedRateLimiter<T> limiter)
        {
            // Uses reflection to stop the limiter's private "_timer" and replace it with an idle one.
            // Returns a delegate that runs the private "Heartbeat" method on demand, so tests drive the heartbeat themselves.
            const Reflection.BindingFlags privateInstance = Reflection.BindingFlags.NonPublic | Reflection.BindingFlags.Instance;
            var limiterType = limiter.GetType();

            var timerField = limiterType.GetField("_timer", privateInstance);
            Assert.NotNull(timerField);

            // Halt the live timer so it cannot fire in the middle of a test.
            var stop = timerField.FieldType.GetMethod("Stop");
            Assert.NotNull(stop);
            stop.Invoke(timerField.GetValue(limiter), Array.Empty<object>());

            // Put a fresh timer of the same type in its place so disposing the limiter later does not fail
            // with an ObjectDisposedException; this replacement is not meant to do any real work.
            var timerCtor = timerField.FieldType.GetConstructor(new[] { typeof(TimeSpan), typeof(TimeSpan) });
            Assert.NotNull(timerCtor);

            var tenMinutes = TimeSpan.FromMinutes(10);
            var replacementTimer = timerCtor.Invoke(new object[] { tenMinutes, tenMinutes });
            Assert.NotNull(replacementTimer);
            timerField.SetValue(limiter, replacementTimer);

            var heartbeat = limiterType.GetMethod("Heartbeat", privateInstance);
            Assert.NotNull(heartbeat);

            return () => (Task)heartbeat.Invoke(limiter, Array.Empty<object>());
        }
        public void Create_MultiplePartitionsWork()
        {
            // Resources "1" and "2" map to distinct partitions, and each partition gets its own limiter.
            var factory = new TrackingRateLimiterFactory<int>();

            using var limiter = PartitionedRateLimiter.Create<string, int>(resource =>
                RateLimitPartition.Get(resource == "1" ? 1 : 2, key => factory.GetLimiter(key)));

            // Interleave the two resources: 1, 2, 1, 2.
            foreach (var resource in new[] { "1", "2", "1", "2" })
            {
                limiter.Acquire(resource);
            }

            Assert.Equal(2, factory.Limiters.Count);

            var first = factory.Limiters[0];
            Assert.Equal(2, first.Limiter.AcquireCallCount);
            Assert.Equal(1, first.Key);

            var second = factory.Limiters[1];
            Assert.Equal(2, second.Limiter.AcquireCallCount);
            Assert.Equal(2, second.Key);
        }
Example #7
0
        public async Task IdleLimiterIsCleanedUp()
        {
            CustomizableLimiter innerLimiter = null;
            var factoryCallCount             = 0;

            using var limiter = PartitionedRateLimiter.Create <string, int>(resource =>
            {
                return(RateLimitPartition.Create(1, _ =>
                {
                    factoryCallCount++;
                    innerLimiter = new CustomizableLimiter();
                    return innerLimiter;
                }));
            });

            var timerLoopMethod = Utils.StopTimerAndGetTimerFunc(limiter);

            var lease = limiter.Acquire("");

            Assert.True(lease.IsAcquired);

            Assert.Equal(1, factoryCallCount);

            var tcs = new TaskCompletionSource <object?>(TaskCreationOptions.RunContinuationsAsynchronously);

            innerLimiter.DisposeAsyncCoreImpl = () =>
            {
                tcs.SetResult(null);
                return(default);
        public async Task Create_MultipleReplenishingLimitersReplenishAutomatically()
        {
            // Two token-bucket partitions ("1" -> key 1, anything else -> key 2), each holding a single token
            // with AutoReplenishment disabled, so only the partitioned limiter can refill them.
            using var limiter = PartitionedRateLimiter.Create<string, int>(resource =>
            {
                var partitionKey = resource == "1" ? 1 : 2;
                return RateLimitPartition.GetTokenBucketLimiter(partitionKey,
                    _ => new TokenBucketRateLimiterOptions(1, QueueProcessingOrder.NewestFirst, 1, TimeSpan.FromMilliseconds(100), 1, false));
            });

            var firstOnOne = limiter.Acquire("1");
            Assert.True(firstOnOne.IsAcquired);

            var refilledOne = await limiter.WaitAndAcquireAsync("1");
            Assert.True(refilledOne.IsAcquired);

            // Partition "2" (the second replenishing limiter) is created only after partition "1" has already been
            // refilled once. This indirectly checks that the timer's cached list of limiters is properly updated.
            var firstOnTwo = limiter.Acquire("2");
            Assert.True(firstOnTwo.IsAcquired);

            var refilledOneAgain = await limiter.WaitAndAcquireAsync("1");
            Assert.True(refilledOneAgain.IsAcquired);

            var refilledTwo = await limiter.WaitAndAcquireAsync("2");
            Assert.True(refilledTwo.IsAcquired);
        }
Example #9
0
    /// <summary>
    /// Creates a new <see cref="RateLimitingMiddleware"/>.
    /// </summary>
    /// <param name="next">The <see cref="RequestDelegate"/> representing the next middleware in the pipeline.</param>
    /// <param name="logger">The <see cref="ILogger"/> used for logging.</param>
    /// <param name="options">
    /// The options for the middleware. Its <c>Limiter</c>, <c>OnRejected</c> and <c>DefaultRejectionStatusCode</c>
    /// values are read once and captured here.
    /// </param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="next"/>, <paramref name="logger"/> or <paramref name="options"/> is <see langword="null"/>.
    /// </exception>
    public RateLimitingMiddleware(RequestDelegate next, ILogger<RateLimitingMiddleware> logger, IOptions<RateLimiterOptions> options)
    {
        _next = next ?? throw new ArgumentNullException(nameof(next));

        _logger = logger ?? throw new ArgumentNullException(nameof(logger));

        // Validate options the same way as the other dependencies; previously a null argument
        // surfaced as a NullReferenceException on the first `.Value` access.
        var settings = (options ?? throw new ArgumentNullException(nameof(options))).Value;

        _limiter             = settings.Limiter;
        _onRejected          = settings.OnRejected;
        _rejectionStatusCode = settings.DefaultRejectionStatusCode;
    }
Example #10
0
        // Looks up the private Heartbeat method on the DefaultPartitionedRateLimiter via reflection,
        // invokes it once and hands back the Task it produced.
        internal static Task RunTimerFunc<T>(PartitionedRateLimiter<T> limiter)
        {
            const BindingFlags privateInstance = BindingFlags.NonPublic | BindingFlags.Instance;
            var limiterType = limiter.GetType();

            // "_timer" is only checked for existence; presumably a sanity check that the reflected layout still matches.
            var timerField = limiterType.GetField("_timer", privateInstance);
            Assert.NotNull(timerField);

            var heartbeat = limiterType.GetMethod("Heartbeat", privateInstance);
            Assert.NotNull(heartbeat);

            return (Task)heartbeat.Invoke(limiter, Array.Empty<object>());
        }
        public void Create_DisposeWithoutLimitersNoops()
        {
            // Disposing before any acquisition must not cause a partition limiter to be created.
            var factory = new TrackingRateLimiterFactory<int>();

            using var limiter = PartitionedRateLimiter.Create<string, int>(
                _ => RateLimitPartition.Get(1, key => factory.GetLimiter(key)));

            limiter.Dispose();

            Assert.Equal(0, factory.Limiters.Count);
        }
        public void Create_GetAvailablePermitsCallsUnderlyingPartitionsLimiter()
        {
            // GetAvailablePermits on the partitioned limiter must be forwarded to the partition's limiter.
            var factory = new TrackingRateLimiterFactory<int>();

            using var limiter = PartitionedRateLimiter.Create<string, int>(
                _ => RateLimitPartition.Get(1, key => factory.GetLimiter(key)));

            limiter.GetAvailablePermits("");

            Assert.Equal(1, factory.Limiters.Count);
            var tracked = factory.Limiters[0].Limiter;
            Assert.Equal(1, tracked.GetAvailablePermitsCallCount);
        }
        public async Task Create_WaitAsyncCallsUnderlyingPartitionsLimiter()
        {
            // WaitAndAcquireAsync on the partitioned limiter must be forwarded to the partition's limiter.
            var factory = new TrackingRateLimiterFactory<int>();

            using var limiter = PartitionedRateLimiter.Create<string, int>(
                _ => RateLimitPartition.Get(1, key => factory.GetLimiter(key)));

            await limiter.WaitAndAcquireAsync("");

            Assert.Equal(1, factory.Limiters.Count);
            var tracked = factory.Limiters[0].Limiter;
            Assert.Equal(1, tracked.WaitAndAcquireAsyncCallCount);
        }
        public async Task Create_BlockingFactoryDoesNotBlockOtherPartitions()
        {
            // While the limiter factory for partition "1" is stalled, partition "2" must keep working.
            // A second request for "1" must wait for that same factory call rather than create another limiter.
            var factory = new TrackingRateLimiterFactory<int>();
            var releaseFactory = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
            var factoryEntered = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);

            using var limiter = PartitionedRateLimiter.Create<string, int>(resource =>
            {
                if (resource != "1")
                {
                    return RateLimitPartition.Get(2,
                                                  key => factory.GetLimiter(key));
                }

                return RateLimitPartition.Get(1, key =>
                {
                    // Tell the test the factory is running, then hold it until the test releases it.
                    factoryEntered.SetResult(null);
                    Assert.True(releaseFactory.Task.Wait(TimeSpan.FromSeconds(10)));
                    return factory.GetLimiter(key);
                });
            });

            _ = await limiter.WaitAndAcquireAsync("2");

            var firstBlocked = Task.Run(async () =>
            {
                await limiter.WaitAndAcquireAsync("1");
            });
            await factoryEntered.Task;

            // Partition "2" is unaffected by the stalled factory of partition "1".
            await limiter.WaitAndAcquireAsync("2");

            // A second request for "1" has to wait for the in-flight factory call instead of creating a new limiter.
            var secondBlocked = Task.Run(async () =>
            {
                await limiter.WaitAndAcquireAsync("1");
            });

            // Let the stalled factory finish.
            releaseFactory.SetResult(null);
            await firstBlocked;
            await secondBlocked;

            // Exactly one limiter per partition.
            Assert.Equal(2, factory.Limiters.Count);
            Assert.Equal(2, factory.Limiters[0].Limiter.WaitAndAcquireAsyncCallCount);
            Assert.Equal(2, factory.Limiters[1].Limiter.WaitAndAcquireAsyncCallCount);
        }
        public void Create_DisposeThrowsForFutureMethodCalls()
        {
            // After Dispose, Acquire must throw ObjectDisposedException without creating a partition limiter.
            var factory = new TrackingRateLimiterFactory<int>();

            using var limiter = PartitionedRateLimiter.Create<string, int>(
                _ => RateLimitPartition.Get(1, key => factory.GetLimiter(key)));

            limiter.Dispose();

            Assert.Throws<ObjectDisposedException>(() => limiter.Acquire("1"));
            Assert.Equal(0, factory.Limiters.Count);
        }
        public async Task Create_WithTokenBucketReplenishesAutomatically()
        {
            // A single-token bucket with AutoReplenishment disabled: the partitioned limiter itself must refill it.
            using var limiter = PartitionedRateLimiter.Create<string, int>(
                _ => RateLimitPartition.GetTokenBucketLimiter(1,
                    partitionKey => new TokenBucketRateLimiterOptions(1, QueueProcessingOrder.NewestFirst, 1, TimeSpan.FromMilliseconds(100), 1, false)));

            var immediate = limiter.Acquire("");
            Assert.True(immediate.IsAcquired);

            // The bucket is now empty, so this request can only be granted after a refill.
            var afterReplenish = await limiter.WaitAndAcquireAsync("");
            Assert.True(afterReplenish.IsAcquired);
        }
Example #17
0
        public async Task Create_PartitionIsCached()
        {
            // Repeated calls for the same resource must reuse one cached partition limiter.
            var factory = new TrackingRateLimiterFactory<int>();

            using var limiter = PartitionedRateLimiter.Create<string, int>(
                _ => RateLimitPartition.Create(1, key => factory.GetLimiter(key)));

            // Two rounds of one synchronous and one asynchronous call each.
            for (var round = 0; round < 2; round++)
            {
                limiter.Acquire("");
                await limiter.WaitAsync("");
            }

            Assert.Equal(1, factory.Limiters.Count);
            var cached = factory.Limiters[0].Limiter;
            Assert.Equal(2, cached.AcquireCallCount);
            Assert.Equal(2, cached.WaitAsyncCallCount);
        }
        public async Task Create_CancellationTokenPassedToUnderlyingLimiter()
        {
            // Cancelling the token passed to WaitAndAcquireAsync must cancel the queued request
            // in the partition's concurrency limiter.
            using var limiter = PartitionedRateLimiter.Create<string, int>(resource =>
            {
                return RateLimitPartition.GetConcurrencyLimiter(1,
                                                                _ => new ConcurrencyLimiterOptions(1, QueueProcessingOrder.NewestFirst, 1));
            });

            // Take the single permit so the next request has to queue.
            var lease = limiter.Acquire("");

            Assert.True(lease.IsAcquired);

            // CancellationTokenSource is IDisposable; it was previously never disposed.
            using var cts = new CancellationTokenSource();
            var waitTask  = limiter.WaitAndAcquireAsync("", 1, cts.Token);

            Assert.False(waitTask.IsCompleted);
            cts.Cancel();
            await Assert.ThrowsAsync<TaskCanceledException>(async () => await waitTask);
        }
Example #19
0
    /// <summary>
    /// Creates a new <see cref="RateLimitingMiddleware"/>.
    /// </summary>
    /// <param name="next">The <see cref="RequestDelegate"/> representing the next middleware in the pipeline.</param>
    /// <param name="logger">The <see cref="ILogger"/> used for logging.</param>
    /// <param name="options">The options for the middleware.</param>
    /// <param name="serviceProvider">The service provider used to activate policies registered through AddPolicy&lt;TPartitionKey, TPolicy&gt;.</param>
    /// <exception cref="ArgumentNullException">Any argument is <see langword="null"/>.</exception>
    public RateLimitingMiddleware(RequestDelegate next, ILogger<RateLimitingMiddleware> logger, IOptions<RateLimiterOptions> options, IServiceProvider serviceProvider)
    {
        ArgumentNullException.ThrowIfNull(next);
        ArgumentNullException.ThrowIfNull(logger);
        // options was previously dereferenced unchecked, so a null argument produced a NullReferenceException.
        ArgumentNullException.ThrowIfNull(options);
        ArgumentNullException.ThrowIfNull(serviceProvider);

        var settings = options.Value;

        _next                = next;
        _logger              = logger;
        _defaultOnRejected   = settings.OnRejected;
        _rejectionStatusCode = settings.RejectionStatusCode;
        // Work on a copy so the additions below do not mutate the options' own PolicyMap.
        _policyMap           = new Dictionary<string, DefaultRateLimiterPolicy>(settings.PolicyMap);

        // Activate policies passed to AddPolicy<TPartitionKey, TPolicy>
        foreach (var unactivatedPolicy in settings.UnactivatedPolicyMap)
        {
            _policyMap.Add(unactivatedPolicy.Key, unactivatedPolicy.Value(serviceProvider));
        }

        _globalLimiter   = settings.GlobalLimiter;
        _endpointLimiter = CreateEndpointLimiter();
    }
Example #20
0
        // Gets and runs the Heartbeat function on the DefaultPartitionedRateLimiter
        // Returns the Task produced by invoking the private Heartbeat method once on the given limiter.
        internal static Task RunTimerFunc <T>(PartitionedRateLimiter <T> limiter)
        {
            // Use Type.GetType so that trimming can see what type we're reflecting on, but assert it's the one we got
            var limiterTypeDef = Type.GetType("System.Threading.RateLimiting.DefaultPartitionedRateLimiter`2, System.Threading.RateLimiting");
            var limiterType    = limiter.GetType();

            // The runtime instance must be a closed DefaultPartitionedRateLimiter<,>.
            Assert.Equal(limiterTypeDef, limiterType.GetGenericTypeDefinition());
            // string.Empty.Length is always 0, so this branch never executes at runtime. It presumably exists only
            // so trimming analysis sees limiterTypeDef flowing into limiterType. Reflection below must keep using the
            // closed type from the instance, because the open generic definition cannot be invoked on an object.
            if (string.Empty.Length > 0)
            {
                limiterType = limiterTypeDef;
            }

            // "_timer" is only checked for existence, as a sanity check on the reflected layout.
            var innerTimer = limiterType.GetField("_timer", BindingFlags.NonPublic | BindingFlags.Instance);

            Assert.NotNull(innerTimer);

            var timerLoopMethod = limiterType.GetMethod("Heartbeat", BindingFlags.NonPublic | BindingFlags.Instance);

            Assert.NotNull(timerLoopMethod);

            return((Task)timerLoopMethod.Invoke(limiter, Array.Empty <object>()));
        }
        public async Task Create_WithTokenBucketReplenishesAutomatically()
        {
            // A single-token bucket with AutoReplenishment disabled: the partitioned limiter itself must refill it.
            using var limiter = PartitionedRateLimiter.Create<string, int>(
                _ => RateLimitPartition.GetTokenBucketLimiter(1, partitionKey => new TokenBucketRateLimiterOptions
                {
                    TokenLimit = 1,
                    QueueProcessingOrder = QueueProcessingOrder.NewestFirst,
                    QueueLimit = 1,
                    ReplenishmentPeriod = TimeSpan.FromMilliseconds(100),
                    TokensPerPeriod = 1,
                    AutoReplenishment = false
                }));

            var immediate = limiter.AttemptAcquire("");
            Assert.True(immediate.IsAcquired);

            // The bucket is now empty, so this request can only be granted after a refill.
            var afterReplenish = await limiter.AcquireAsync("");
            Assert.True(afterReplenish.IsAcquired);
        }
Example #22
0
        public void Create_DisposeWithThrowingDisposes_DisposesAllLimiters()
        {
            // Both partition limiters throw from Dispose. The partitioned limiter must still dispose each one
            // and surface both failures in a single AggregateException.
            var first = new CustomizableLimiter { DisposeImpl = _ => throw new Exception() };
            var second = new CustomizableLimiter { DisposeImpl = _ => throw new Exception() };

            using var limiter = PartitionedRateLimiter.Create<string, int>(resource =>
                resource == "1"
                    ? RateLimitPartition.Create(1, _ => first)
                    : RateLimitPartition.Create(2, _ => second));

            limiter.Acquire("1");
            limiter.Acquire("2");

            var aggregate = Assert.Throws<AggregateException>(() => limiter.Dispose());
            Assert.Equal(2, aggregate.InnerExceptions.Count);
        }
        public async Task Create_DisposeAsyncWithThrowingDisposes_DisposesAllLimiters()
        {
            // Both partition limiters throw from DisposeAsyncCore. DisposeAsync on the partitioned limiter must still
            // dispose each one and surface both failures in a single AggregateException.
            var first = new CustomizableLimiter { DisposeAsyncCoreImpl = () => throw new Exception() };
            var second = new CustomizableLimiter { DisposeAsyncCoreImpl = () => throw new Exception() };

            using var limiter = PartitionedRateLimiter.Create<string, int>(resource =>
                resource == "1"
                    ? RateLimitPartition.Get(1, _ => first)
                    : RateLimitPartition.Get(2, _ => second));

            limiter.Acquire("1");
            limiter.Acquire("2");

            var aggregate = await Assert.ThrowsAsync<AggregateException>(() => limiter.DisposeAsync().AsTask());
            Assert.Equal(2, aggregate.InnerExceptions.Count);
        }