Implementing extension method WebRequest.GetResponseAsync with support for CancellationToken

前端 未结 1 1889
小蘑菇
小蘑菇 2020-12-15 13:30

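The idea here is simple, but the implementation has some interesting nuances. This is the signature of the extension method I would like to implement in .NET 4:

    public static Task<WebResponse> GetResponseAsync(this WebRequest request, CancellationToken token)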

1条回答
  •  臣服心动
    2020-12-15 13:53

    Will this implementation actually behave as expected when used with the TPL?

    No.

    1. It will not mark the resulting Task as canceled, so its behavior will not be exactly what callers expect.
    2. In the event of a timeout, the WebException inside the AggregateException reported by Task.Exception will have the status WebExceptionStatus.RequestCanceled; it should instead be WebExceptionStatus.Timeout.

    I would actually recommend using TaskCompletionSource<WebResponse> to implement this. It lets you write the code without creating your own APM-style (Begin/End) methods:

    /// <summary>
    /// Sends <paramref name="request"/> asynchronously and exposes the response as a task.
    /// The task honors <see cref="WebRequest.Timeout"/> and a <see cref="CancellationToken"/>.
    /// </summary>
    /// <param name="request">The request to send.</param>
    /// <param name="token">
    /// A token that aborts the request when canceled. Defaults to no cancellation.
    /// </param>
    /// <returns>
    /// A task that does one of the following:
    /// <list type="bullet">
    /// <item>Completes with the response.</item>
    /// <item>Is canceled when <paramref name="token"/> aborted the request.</item>
    /// <item>Faults with a <see cref="WebException"/> whose status is <see cref="WebExceptionStatus.Timeout"/>
    /// when <see cref="WebRequest.Timeout"/> elapsed first.</item>
    /// <item>Faults with whatever exception <see cref="WebRequest.EndGetResponse"/> threw.</item>
    /// </list>
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="request"/> is null.</exception>
    public static Task<WebResponse> GetResponseAsync(this WebRequest request, CancellationToken token = default(CancellationToken))
    {
        if (request == null)
            throw new ArgumentNullException("request");

        // 1 once the timer aborted the request. Written and read on different thread-pool
        // threads, hence Interlocked. Lets the completion callback report a timeout
        // rather than a cancellation.
        int timedOut = 0;
        var completionSource = new TaskCompletionSource<WebResponse>();

        AsyncCallback completedCallback = result =>
        {
            try
            {
                completionSource.TrySetResult(request.EndGetResponse(result));
            }
            catch (WebException ex)
            {
                // Abort() surfaces as a WebException for both timeout and cancellation;
                // translate it into the outcome the caller expects.
                if (Interlocked.CompareExchange(ref timedOut, 0, 0) == 1)
                    completionSource.TrySetException(new WebException("No response was received during the time-out period for a request.", WebExceptionStatus.Timeout));
                else if (token.IsCancellationRequested)
                    completionSource.TrySetCanceled();
                else
                    completionSource.TrySetException(ex);
            }
            catch (Exception ex)
            {
                completionSource.TrySetException(ex);
            }
        };

        IAsyncResult asyncResult = request.BeginGetResponse(completedCallback, null);
        if (!asyncResult.IsCompleted)
        {
            // WebRequest.Timeout is not enforced for Begin/EndGetResponse, so arm a
            // thread-pool timer on the operation's wait handle and abort when it fires.
            RegisteredWaitHandle timeoutRegistration = null;
            if (request.Timeout != Timeout.Infinite)
            {
                WaitOrTimerCallback timedOutCallback = (state, isTimedOut) =>
                {
                    if (isTimedOut)
                    {
                        Interlocked.Exchange(ref timedOut, 1);
                        request.Abort();
                    }
                };

                timeoutRegistration = ThreadPool.RegisterWaitForSingleObject(asyncResult.AsyncWaitHandle, timedOutCallback, null, request.Timeout, true);
            }

            // The callback runs immediately if the token is already canceled.
            CancellationTokenRegistration cancellationRegistration = default(CancellationTokenRegistration);
            if (token.CanBeCanceled)
                cancellationRegistration = token.Register(() => request.Abort());

            // Release both registrations once the task finishes. Otherwise a long-lived token
            // keeps this request and closure alive. The continuation is queued to the pool,
            // not run inline, so disposing the token registration never blocks inside the
            // Abort() call that completed the task.
            completionSource.Task.ContinueWith(
                _ =>
                {
                    if (timeoutRegistration != null)
                        timeoutRegistration.Unregister(null);
                    cancellationRegistration.Dispose();
                },
                CancellationToken.None,
                TaskContinuationOptions.None,
                TaskScheduler.Default);
        }

        return completionSource.Task;
    }
    

    The advantage here is that the returned Task behaves exactly as expected. It is marked as canceled when the token is canceled. On timeout, it faults with the same WebException (status Timeout) that the synchronous version throws. This also avoids the overhead of Task.Factory.FromAsync, since you are already handling most of the difficult work involved yourself.


    Addendum by 280Z28

    Here are unit tests demonstrating correct operation of the method above.

    /// <summary>
    /// Tests for the GetResponseAsync(WebRequest, CancellationToken) extension method.
    /// NOTE: every test sends a real request to http://google.com and needs network access.
    /// Each test passes a token explicitly. On .NET 4.5+ a parameterless call would bind to
    /// the built-in WebRequest.GetResponseAsync() instance method instead of the extension.
    /// </summary>
    [TestClass]
    public class AsyncWebRequestTests
    {
        /// <summary>A reachable URI yields a successfully completed task.</summary>
        [TestMethod]
        public void TestAsyncWebRequest()
        {
            Uri uri = new Uri("http://google.com");
            WebRequest request = WebRequest.Create(uri);
            Task<WebResponse> response = request.GetResponseAsync(CancellationToken.None);

            using (WebResponse result = response.Result)
            {
                Assert.AreEqual(TaskStatus.RanToCompletion, response.Status);
            }
        }

        /// <summary>A zero timeout faults the task with a WebException whose status is Timeout.</summary>
        [TestMethod]
        public void TestAsyncWebRequestTimeout()
        {
            Uri uri = new Uri("http://google.com");
            WebRequest request = WebRequest.Create(uri);
            request.Timeout = 0;
            Task<WebResponse> response = request.GetResponseAsync(CancellationToken.None);
            try
            {
                response.Wait();
                Assert.Fail("Expected an exception");
            }
            catch (AggregateException exception)
            {
                Assert.AreEqual(TaskStatus.Faulted, response.Status);

                ReadOnlyCollection<Exception> exceptions = exception.InnerExceptions;
                Assert.AreEqual(1, exceptions.Count);
                Assert.IsInstanceOfType(exceptions[0], typeof(WebException));

                WebException webException = (WebException)exceptions[0];
                Assert.AreEqual(WebExceptionStatus.Timeout, webException.Status);
            }
        }

        /// <summary>Canceling the token while the request is in flight yields a canceled task.</summary>
        [TestMethod]
        public void TestAsyncWebRequestCancellation()
        {
            Uri uri = new Uri("http://google.com");
            WebRequest request = WebRequest.Create(uri);
            using (CancellationTokenSource cancellationTokenSource = new CancellationTokenSource())
            {
                Task<WebResponse> response = request.GetResponseAsync(cancellationTokenSource.Token);
                // NOTE(review): races with the response arriving; assumes the request is still in flight.
                cancellationTokenSource.Cancel();
                try
                {
                    response.Wait();
                    Assert.Fail("Expected an exception");
                }
                catch (AggregateException exception)
                {
                    Assert.AreEqual(TaskStatus.Canceled, response.Status);

                    ReadOnlyCollection<Exception> exceptions = exception.InnerExceptions;
                    Assert.AreEqual(1, exceptions.Count);
                    Assert.IsInstanceOfType(exceptions[0], typeof(OperationCanceledException));
                }
            }
        }

        /// <summary>A 404 faults the task with a WebException that carries the HTTP response.</summary>
        [TestMethod]
        public void TestAsyncWebRequestError()
        {
            Uri uri = new Uri("http://google.com/fail");
            WebRequest request = WebRequest.Create(uri);
            Task<WebResponse> response = request.GetResponseAsync(CancellationToken.None);
            try
            {
                response.Wait();
                Assert.Fail("Expected an exception");
            }
            catch (AggregateException exception)
            {
                Assert.AreEqual(TaskStatus.Faulted, response.Status);

                ReadOnlyCollection<Exception> exceptions = exception.InnerExceptions;
                Assert.AreEqual(1, exceptions.Count);
                Assert.IsInstanceOfType(exceptions[0], typeof(WebException));

                WebException webException = (WebException)exceptions[0];
                using (HttpWebResponse errorResponse = (HttpWebResponse)webException.Response)
                {
                    Assert.AreEqual(HttpStatusCode.NotFound, errorResponse.StatusCode);
                }
            }
        }
    }
    

    0 讨论(0)
提交回复
热议问题