How to throttle WCF service per client

傲寒 2020-12-19 22:51

I'm developing a service that will be exposed on the internet to only a few select clients. However, I don't want one client to be able to call the service so often that th…

3 answers
  • 2020-12-19 23:31

    First, if you are using some kind of load balancer, it is best to implement rate limiting there. For instance, NGINX has rate-limiting capabilities: http://nginx.org/en/docs/http/ngx_http_limit_req_module.html .

    Second, you should consider using the built-in IIS rate-limiting feature called Dynamic IP Restrictions: http://www.iis.net/learn/get-started/whats-new-in-iis-8/iis-80-dynamic-ip-address-restrictions .

    If neither of these is enough because you need custom logic, you can always implement rate limiting at the application level. This can be done in several ways.

    Let's start with some reusable rate-limiting logic:

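    using System;
    using System.ComponentModel;
    using System.Configuration;
    using System.Net;
    using System.Runtime.Caching;
    using System.ServiceModel;
    using System.ServiceModel.Channels;
    using System.Web;
    
    public interface IRateLimiter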
    {
        bool ShouldLimit(string key);
    
        HttpStatusCode LimitStatusCode { get; }
    }
    
    public interface IRateLimiterConfiguration
    {
        int Treshhold { get; set; }
        TimeSpan TimePeriod { get; set; }
        HttpStatusCode LimitStatusCode { get; set; }
    }
    
    public class RateLimiterConfiguration : System.Configuration.ConfigurationSection, IRateLimiterConfiguration
    {
        private const string TimePeriodConst = "timePeriod";
        private const string LimitStatusCodeConst = "limitStatusCode";
        private const string TreshholdConst = "treshhold";
        private const string RateLimiterTypeConst = "rateLimiterType";
    
        [ConfigurationProperty(TreshholdConst, IsRequired = true, DefaultValue = 10)]
        public int Treshhold
        {
            get { return (int)this[TreshholdConst]; }
            set { this[TreshholdConst] = value; }
        }
    
        [ConfigurationProperty(TimePeriodConst, IsRequired = true)]
        [TypeConverter(typeof(TimeSpanConverter))]
        public TimeSpan TimePeriod
        {
            get { return (TimeSpan)this[TimePeriodConst]; }
            set { this[TimePeriodConst] = value; }
        }
    
        [ConfigurationProperty(LimitStatusCodeConst, IsRequired = false, DefaultValue = HttpStatusCode.Forbidden)]
        public HttpStatusCode LimitStatusCode
        {
            get { return (HttpStatusCode)this[LimitStatusCodeConst]; }
            set { this[LimitStatusCodeConst] = value; }
        }
    
        [ConfigurationProperty(RateLimiterTypeConst, IsRequired = true)]
        [TypeConverter(typeof(TypeNameConverter))]
        public Type RateLimiterType
        {
            get { return (Type)this[RateLimiterTypeConst]; }
            set { this[RateLimiterTypeConst] = value; }
        }
    }
    
    public class RateLimiter : IRateLimiter
    {
        private readonly IRateLimiterConfiguration _configuration;        
        private static readonly MemoryCache MemoryCache = MemoryCache.Default;
    
        public RateLimiter(IRateLimiterConfiguration configuration)
        {
            _configuration = configuration;
        }
    
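        public virtual bool ShouldLimit(string key)
        {
            if (string.IsNullOrEmpty(key))
            {
                return false;
            }
    
            // AddOrGetExisting returns null when newCounter was inserted and the existing
            // entry otherwise, so the instance stored in the cache is always the one counted.
            // The entry expires TimePeriod after the first request (a fixed window per key).
            var newCounter = new Counter();
            var counter = MemoryCache.AddOrGetExisting(key, newCounter, DateTimeOffset.Now.Add(_configuration.TimePeriod)) as Counter ?? newCounter;
            lock (counter.LockObject)
            {
                if (counter.Count >= _configuration.Treshhold)
                {
                    return true;
                }
    
                counter.Count++;
                return false;
            }
        }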
    
        public HttpStatusCode LimitStatusCode
        {
            get { return _configuration.LimitStatusCode; }
        }
    
        private class Counter
        {
            public volatile int Count;
            public readonly object LockObject = new object();
        }
    }
    
    public class RateLimiterFactory
    {
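        public IRateLimiter CreateRateLimiter()
        {
            var configuration = GetConfiguration();
            // The configured type must expose a constructor taking IRateLimiterConfiguration.
            return (IRateLimiter)Activator.CreateInstance(configuration.RateLimiterType, configuration);
        }
    
        public static RateLimiterConfiguration GetConfiguration()
        {
            var configuration = ConfigurationManager.GetSection("rateLimiter") as RateLimiterConfiguration;
            if (configuration == null)
            {
                // rateLimiterType and timePeriod are required, so there is no usable default section.
                throw new ConfigurationErrorsException("The <rateLimiter> configuration section is missing.");
            }
            return configuration;
        }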
    }
    
    static class GetClientIpExtensions
    {
        private const string XForwardedForHeaderName = "X-Forwarded-For";
        private const string HttpXForwardedForServerVariableName = "HTTP_X_FORWARDED_FOR";
        private const string HttpRemoteAddressServerVariableName = "REMOTE_ADDR";
    
        public static string GetClientIp(this Message message)
        {
            return GetClientIp(message.Properties);
        }
    
        public static string GetClientIp(this OperationContext context)
        {
            return GetClientIp(context.IncomingMessageProperties);
        }
    
        public static string GetClientIp(this MessageProperties messageProperties)
        {
            // X-Forwarded-For can be set by any client, so only trust it when the
            // service sits behind a proxy or load balancer you control.
            var httpRequest = messageProperties[HttpRequestMessageProperty.Name] as HttpRequestMessageProperty;
            string forwardedFor = httpRequest == null ? null : httpRequest.Headers[XForwardedForHeaderName];
            if (!string.IsNullOrEmpty(forwardedFor))
            {
                // The header may hold a comma-separated list; the first entry is the originating client.
                return forwardedFor.Split(',')[0].Trim();
            }
    
            var endpointProperty = messageProperties[RemoteEndpointMessageProperty.Name] as RemoteEndpointMessageProperty;
            return endpointProperty == null ? string.Empty : endpointProperty.Address;
        }
    
        public static string GetClientIp(this HttpRequest request)
        {
            string ipList = request.ServerVariables[HttpXForwardedForServerVariableName];
            return !string.IsNullOrEmpty(ipList) ? ipList.Split(',')[0].Trim() : request.ServerVariables[HttpRemoteAddressServerVariableName];
        }
    }
    

    This uses configuration, proper segregation through interfaces, and the default MemoryCache. You can easily change the implementation to abstract the cache, which would let you use a different cache provider such as Redis. That is useful if you want a distributed cache shared by multiple servers running the same service.
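    As a rough sketch of that abstraction (ICounterStore, InMemoryCounterStore and StoreBackedRateLimiter are hypothetical names, not part of the answer above), the counter storage can sit behind a small interface so the limiter never touches MemoryCache directly. The in-memory store keeps a fixed window per key; a Redis-backed store could, as an assumption, implement the same Increment method with Redis's INCR and EXPIRE commands.

```csharp
using System;
using System.Collections.Concurrent;

// Hypothetical storage abstraction so the counter backend (in-process, Redis, ...) can be swapped.
public interface ICounterStore
{
    // Increments the counter for key within a fixed window and returns the updated count.
    long Increment(string key, TimeSpan window);
}

// In-process store: one counter per key, restarted once its window has elapsed.
// Entries are never evicted, which is acceptable for a handful of known clients.
public sealed class InMemoryCounterStore : ICounterStore
{
    private sealed class FixedWindow
    {
        public DateTime StartUtc;
        public long Count;
    }

    private readonly ConcurrentDictionary<string, FixedWindow> _windows =
        new ConcurrentDictionary<string, FixedWindow>();
    private readonly Func<DateTime> _utcNow;

    // The clock is injectable so the window logic can be tested deterministically.
    public InMemoryCounterStore(Func<DateTime> utcNow = null)
    {
        _utcNow = utcNow ?? (() => DateTime.UtcNow);
    }

    public long Increment(string key, TimeSpan window)
    {
        DateTime now = _utcNow();
        FixedWindow entry = _windows.GetOrAdd(key, _ => new FixedWindow { StartUtc = now });
        lock (entry)
        {
            if (now - entry.StartUtc >= window)
            {
                // The previous window is over: start counting again from this request.
                entry.StartUtc = now;
                entry.Count = 0;
            }
            return ++entry.Count;
        }
    }
}

// Limiter that depends only on ICounterStore, not on a concrete cache.
public sealed class StoreBackedRateLimiter
{
    private readonly ICounterStore _store;
    private readonly int _threshold;
    private readonly TimeSpan _window;

    public StoreBackedRateLimiter(ICounterStore store, int threshold, TimeSpan window)
    {
        _store = store;
        _threshold = threshold;
        _window = window;
    }

    // True once a key has made more than _threshold requests in the current window.
    public bool ShouldLimit(string key)
    {
        if (string.IsNullOrEmpty(key))
        {
            return false;
        }
        return _store.Increment(key, _window) > _threshold;
    }
}
```

    Swapping in a distributed store would only require another ICounterStore implementation; the limiter and anything that calls it stay unchanged.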

    With this code as a base, we can add implementations that use it. For example, an IHttpModule:

    using System;
    using System.Web;
    
    public class RateLimiterHttpModule : IHttpModule
    {
        private readonly IRateLimiter _rateLimiter;
    
        public RateLimiterHttpModule()
        {
            _rateLimiter = new RateLimiterFactory().CreateRateLimiter();
        }
    
        public void Init(HttpApplication context)
        {
            context.BeginRequest += OnBeginRequest;
        }
    
        private void OnBeginRequest(object sender, EventArgs e)
        {
            HttpApplication application = (HttpApplication)sender;
            string ip = application.Context.Request.GetClientIp();
            if (_rateLimiter.ShouldLimit(ip))
            {
                TerminateRequest(application);
            }
        }
    
        private void TerminateRequest(HttpApplication application)
        {
            HttpResponse httpResponse = application.Context.Response;
            httpResponse.StatusCode = (int)_rateLimiter.LimitStatusCode;
            httpResponse.SuppressContent = true;
            // CompleteRequest skips straight to EndRequest without the
            // ThreadAbortException that Response.End() raises.
            application.CompleteRequest();
        }
    
        public void Dispose()
        {
        }
    }
    

    Or a WCF-only implementation that works with any transport:

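    using System;
    using System.Collections.ObjectModel;
    using System.Net;
    using System.ServiceModel;
    using System.ServiceModel.Channels;
    using System.ServiceModel.Configuration;
    using System.ServiceModel.Description;
    using System.ServiceModel.Dispatcher;
    
    public class RateLimiterDispatchMessageInspector : IDispatchMessageInspector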
    {
        private readonly IRateLimiter _rateLimiter;
    
        public RateLimiterDispatchMessageInspector(IRateLimiter rateLimiter)
        {
            _rateLimiter = rateLimiter;
        }
    
        public object AfterReceiveRequest(ref Message request, IClientChannel channel, InstanceContext instanceContext)
        {
            if (_rateLimiter.ShouldLimit(request.GetClientIp()))
            {
                request = null;
                return _rateLimiter.LimitStatusCode;
            }
            return null;
        }
    
        public void BeforeSendReply(ref Message reply, object correlationState)
        {
            // reply is null for one-way operations, so there is nothing to annotate.
            if (reply != null && correlationState is HttpStatusCode)
            {
                var responseProperty = new HttpResponseMessageProperty
                {
                    StatusCode = (HttpStatusCode)correlationState
                };
                reply.Properties[HttpResponseMessageProperty.Name] = responseProperty;
            }
        }
    }
    
    public class RateLimiterServiceBehavior : IServiceBehavior
    {
        public void Validate(ServiceDescription serviceDescription, ServiceHostBase serviceHostBase) { }
    
        public void AddBindingParameters(ServiceDescription serviceDescription, ServiceHostBase serviceHostBase, Collection<ServiceEndpoint> endpoints, BindingParameterCollection bindingParameters) { }
    
        public void ApplyDispatchBehavior(ServiceDescription serviceDescription, ServiceHostBase serviceHostBase)
        {
            var rateLimiterFactory = new RateLimiterFactory();
    
            foreach (ChannelDispatcher chDisp in serviceHostBase.ChannelDispatchers)
            {
                foreach (EndpointDispatcher epDisp in chDisp.Endpoints)
                {
                    epDisp.DispatchRuntime.MessageInspectors.Add(new RateLimiterDispatchMessageInspector(rateLimiterFactory.CreateRateLimiter()));
                }
            }
        }
    }
    
    public class RateLimiterBehaviorExtensionElement : BehaviorExtensionElement
    {
        protected override object CreateBehavior()
        {
            return new RateLimiterServiceBehavior();
        }
    
        public override Type BehaviorType
        {
            get { return typeof(RateLimiterServiceBehavior); }
        }
    }
    

    You can write a similar action filter for ASP.NET MVC; see the question "How do I implement rate limiting in an ASP.NET MVC site?".

  • 2020-12-19 23:32

    You could change the service behavior like this:

     <ServiceBehavior(ConcurrencyMode:=ConcurrencyMode.Single, InstanceContextMode:=InstanceContextMode.Single)>
    

    Then, in code, set up two service-level variables:

    • one string variable to hold the ID of the last requestor
    • one integer variable for the number of visits.

    On each request where the ID of the inbound user matches the last saved user, increment the integer variable by 1. Once that user has made 10 consecutive requests, return a denial for any further ones. If the user is different, reset the variables (store the new ID and restart the count) and process the request.
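    A minimal C# sketch of the two-variable approach just described (LastRequestorGate and TryAdmit are hypothetical names). Like the description, it only counts consecutive requests from the same caller and has no time window. It assumes the singleton, single-concurrency service behavior shown above, under which WCF never runs two calls on the instance at once, so it takes no lock of its own.

```csharp
using System;

// Hypothetical helper holding the two service-level variables from the answer.
// Assumes InstanceContextMode.Single with ConcurrencyMode.Single, so WCF
// never calls TryAdmit concurrently and no locking is needed here.
public sealed class LastRequestorGate
{
    private readonly int _maxConsecutiveRequests;
    private string _lastRequestorId; // ID of the last caller
    private int _visitCount;         // consecutive requests from that caller

    public LastRequestorGate(int maxConsecutiveRequests = 10)
    {
        _maxConsecutiveRequests = maxConsecutiveRequests;
    }

    // Returns true if the request should be processed, false if it should be denied.
    public bool TryAdmit(string requestorId)
    {
        if (string.Equals(requestorId, _lastRequestorId, StringComparison.Ordinal))
        {
            _visitCount++;
        }
        else
        {
            // A different caller: remember it and restart the count.
            _lastRequestorId = requestorId;
            _visitCount = 1;
        }

        // Requests 1..max are allowed; every later consecutive request is denied.
        return _visitCount <= _maxConsecutiveRequests;
    }
}
```

    Because any request from another caller resets the count, two clients calling alternately are never limited by this scheme.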

    It's not a pure configuration solution (it's configuration plus code), but it would work.

  • 2020-12-19 23:36

    The term you're looking for here is rate limiting. And no, there's no built-in way to rate-limit a WCF service. As you said, you can use WCF's service throttling features, but those are service-level settings, not per-client ones.

    To implement rate limiting, the general guidance seems to be to use an in-memory collection (or something like Redis for scale-out scenarios) for fast look-ups keyed by the incoming user identifier or IP address. You can then define a limiting algorithm around that information.
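    One such limiting algorithm is a sliding-window log, sketched here in C# as an illustration rather than as the approach from the linked material (SlidingWindowRateLimiter and TryAcquire are hypothetical names). Each key, such as a user name or IP address, maps to the timestamps of its recent allowed requests in a ConcurrentDictionary; a new request passes only if fewer than maxRequests timestamps fall inside the window.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;

// Hypothetical sliding-window-log limiter keyed by user identifier or IP address.
public sealed class SlidingWindowRateLimiter
{
    private readonly ConcurrentDictionary<string, Queue<DateTime>> _log =
        new ConcurrentDictionary<string, Queue<DateTime>>();
    private readonly int _maxRequests;
    private readonly TimeSpan _window;
    private readonly Func<DateTime> _utcNow;

    // The clock is injectable so the sliding behavior can be tested deterministically.
    public SlidingWindowRateLimiter(int maxRequests, TimeSpan window, Func<DateTime> utcNow = null)
    {
        _maxRequests = maxRequests;
        _window = window;
        _utcNow = utcNow ?? (() => DateTime.UtcNow);
    }

    // Returns true when the caller identified by key may proceed.
    public bool TryAcquire(string key)
    {
        DateTime now = _utcNow();
        Queue<DateTime> timestamps = _log.GetOrAdd(key, _ => new Queue<DateTime>());
        lock (timestamps)
        {
            // Drop timestamps that have slid out of the window.
            while (timestamps.Count > 0 && now - timestamps.Peek() >= _window)
            {
                timestamps.Dequeue();
            }

            if (timestamps.Count >= _maxRequests)
            {
                return false; // denied requests are not recorded
            }

            timestamps.Enqueue(now);
            return true;
        }
    }
}
```

    Compared with a fixed window, this avoids letting a client send up to twice the limit across a window boundary, at the cost of storing up to maxRequests timestamps per key. Keys themselves are never evicted in this sketch.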

    More info here and here.
