How to implement Task.WhenAny() with a predicate

半阙折子戏 2020-12-06 21:03

I want to execute several asynchronous tasks concurrently. Each task will run an HTTP request that can either complete successfully or throw an exception. I need to aw

5 answers
  • 2020-12-06 21:35
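    public static Task<T> GetFirstResult<T>(
        ICollection<Func<CancellationToken, Task<T>>> taskFactories, 
        Predicate<T> predicate) where T : class
    {
        if (taskFactories.Count == 0)
            throw new ArgumentException("No task factories supplied", nameof(taskFactories));
    
        // Run awaiters of the returned task asynchronously, not inline in our continuation
        var tcs = new TaskCompletionSource<T>(TaskCreationOptions.RunContinuationsAsynchronously);
        var cts = new CancellationTokenSource();
    
        int completedCount = 0;
        // in case you have a lot of tasks you might need to throttle them 
        //(e.g. so you don't try to send 99999999 requests at the same time)
        // see: http://stackoverflow.com/a/25877042/67824
        foreach (var taskFactory in taskFactories)
        {
            taskFactory(cts.Token).ContinueWith(t => 
            {
                try
                {
                    // Faulted and canceled tasks have no Result (t.Result would throw),
                    // so check the status rather than t.Exception.
                    if (t.Status != TaskStatus.RanToCompletion)
                    {
                        Console.WriteLine($"Task did not succeed ({t.Status}): {t.Exception}");
                    }
                    else if (predicate(t.Result) && tcs.TrySetResult(t.Result))
                    {
                        cts.Cancel(); // only the winning task cancels the others
                    }
                }
                catch (Exception ex)
                {
                    tcs.TrySetException(ex); // e.g. the predicate threw
                }
    
                if (Interlocked.Increment(ref completedCount) == taskFactories.Count)
                {
                    // TrySet*, not Set*: a matching result may already have been set.
                    tcs.TrySetException(new InvalidOperationException("No task produced a result that satisfied the predicate"));
                }
    
            }, cts.Token);
        }
    
        return tcs.Task;
    }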
    

    Sample usage:

    using System.Net.Http;
    var client = new HttpClient();
    var response = await GetFirstResult(
        new Func<CancellationToken, Task<HttpResponseMessage>>[] 
        {
            ct => client.GetAsync("http://microsoft123456.com", ct),
            ct => client.GetAsync("http://microsoft123456.com", ct),
            ct => client.GetAsync("http://microsoft123456.com", ct),
            ct => client.GetAsync("http://microsoft123456.com", ct),
            ct => client.GetAsync("http://microsoft123456.com", ct),
            ct => client.GetAsync("http://microsoft123456.com", ct),
            ct => client.GetAsync("http://microsoft123456.com", ct),
            ct => client.GetAsync("http://microsoft.com", ct),
            ct => client.GetAsync("http://microsoft123456.com", ct),
            ct => client.GetAsync("http://microsoft123456.com", ct),
        }, 
        rm => rm.IsSuccessStatusCode);
    Console.WriteLine($"Successful response: {response}");
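    The code comment above suggests throttling when there are many factories. Here is a minimal sketch of one way to do that. It assumes a SemaphoreSlim gate is acceptable and wraps every factory so that at most maxConcurrency of them run at once. ThrottleHelper.Throttle is a hypothetical helper, not part of the answer or of any library. Because it returns a List, the result can be passed straight to GetFirstResult.

```csharp
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

public static class ThrottleHelper
{
    // Hypothetical helper: wraps each factory so that no more than
    // maxConcurrency of the produced tasks run at the same time.
    public static List<Func<CancellationToken, Task<T>>> Throttle<T>(
        IEnumerable<Func<CancellationToken, Task<T>>> factories,
        int maxConcurrency)
    {
        if (factories == null) throw new ArgumentNullException(nameof(factories));
        if (maxConcurrency < 1) throw new ArgumentOutOfRangeException(nameof(maxConcurrency));

        var gate = new SemaphoreSlim(maxConcurrency, maxConcurrency);
        var wrapped = new List<Func<CancellationToken, Task<T>>>();

        foreach (var factory in factories)
        {
            wrapped.Add(async ct =>
            {
                // Wait for a free slot; a cancelled token abandons the wait.
                await gate.WaitAsync(ct).ConfigureAwait(false);
                try
                {
                    return await factory(ct).ConfigureAwait(false);
                }
                finally
                {
                    // Free the slot whether the request succeeded or failed.
                    gate.Release();
                }
            });
        }
        return wrapped;
    }
}
```

    Waiting factories observe the same token, so cancelling after the first match also abandons requests that have not started yet.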
    
  • 2020-12-06 21:40

    Await any of the tasks and return it if it satisfies the condition. Otherwise, remove it and keep waiting on the remaining tasks. If no task is left, return null.

    public static async Task<Task> WhenAny( IEnumerable<Task> tasks, Predicate<Task> condition )
    {
        var tasklist = tasks.ToList();
        while ( tasklist.Count > 0 )
        {
            var task = await Task.WhenAny( tasklist );
            if ( condition( task ) )
                return task;
            tasklist.Remove( task );
        }
        return null;
    }
    

    A simple check:

    var tasks = new List<Task> {
        Task.FromException( new Exception() ),
        Task.FromException( new Exception() ),
        Task.FromException( new Exception() ),
        Task.CompletedTask, };
    
    var completedTask = WhenAny( tasks, t => t.Status == TaskStatus.RanToCompletion ).Result;
    
    if ( tasks.IndexOf( completedTask ) != 3 )
        throw new Exception( "not expected" );
    
  • 2020-12-06 21:40
    public static Task<Task<T>> WhenFirst<T>(IEnumerable<Task<T>> tasks, Func<Task<T>, bool> predicate)
    {
        if (tasks == null) throw new ArgumentNullException(nameof(tasks));
        if (predicate == null) throw new ArgumentNullException(nameof(predicate));
    
        var tasksArray = (tasks as IReadOnlyList<Task<T>>) ?? tasks.ToArray();
        if (tasksArray.Count == 0) throw new ArgumentException("Empty task list", nameof(tasks));
        if (tasksArray.Any(t => t == null)) throw new ArgumentException("Tasks contains a null reference", nameof(tasks));
    
        var tcs = new TaskCompletionSource<Task<T>>();
        var count = tasksArray.Count;
    
        Action<Task<T>> continuation = t =>
            {
                if (predicate(t))
                {
                    tcs.TrySetResult(t);
                }
                if (Interlocked.Decrement(ref count) == 0)
                {
                    tcs.TrySetResult(null);
                }
            };
    
        foreach (var task in tasksArray)
        {
            task.ContinueWith(continuation);
        }
    
        return tcs.Task;
    }
    

    Sample usage:

    var task = await WhenFirst(tasks, t => t.Status == TaskStatus.RanToCompletion);
    
    if (task != null)
    {
        var value = await task;
    }
    

    Note that this doesn't propagate exceptions of failed tasks (just as WhenAny doesn't).
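    To make the parenthetical concrete: Task.WhenAny itself completes successfully even when the task it picks has faulted. The caller has to inspect or await the returned task to see the exception. Below is a small self-contained check of that behaviour. WhenAnyFaultDemo is a hypothetical name used only for illustration.

```csharp
using System;
using System.Threading.Tasks;

public static class WhenAnyFaultDemo
{
    // Returns the task picked by Task.WhenAny and whether awaiting WhenAny itself threw.
    public static async Task<(Task Winner, bool WhenAnyThrew)> RunAsync()
    {
        var faulted = Task.FromException<int>(new InvalidOperationException("boom"));
        var pending = new TaskCompletionSource<int>().Task; // never completes

        Task winner;
        try
        {
            // Completes with the faulted task; does not rethrow its exception.
            winner = await Task.WhenAny(faulted, pending);
        }
        catch
        {
            return (null, true);
        }
        return (winner, false);
    }
}
```

    Only awaiting the returned task (await winner) rethrows the InvalidOperationException.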

    You can also create a version of this for the non-generic Task.
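    Here is a minimal sketch of such a non-generic version. It assumes the same semantics as the generic WhenFirst: the result is the first task that satisfies the predicate, or null if none does. Two departures from the generic code are choices made for this sketch only:
    - continuations run on TaskScheduler.Default;
    - an exception thrown by the predicate faults the returned task instead of being lost.
    NonGenericWhenFirst is a hypothetical class name.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class NonGenericWhenFirst
{
    // Completes with the first task satisfying the predicate, or null if none does.
    public static Task<Task> WhenFirst(IEnumerable<Task> tasks, Func<Task, bool> predicate)
    {
        if (tasks == null) throw new ArgumentNullException(nameof(tasks));
        if (predicate == null) throw new ArgumentNullException(nameof(predicate));

        var tasksArray = tasks.ToArray();
        if (tasksArray.Length == 0) throw new ArgumentException("Empty task list", nameof(tasks));
        if (tasksArray.Any(t => t == null)) throw new ArgumentException("Tasks contains a null reference", nameof(tasks));

        var tcs = new TaskCompletionSource<Task>(TaskCreationOptions.RunContinuationsAsynchronously);
        var count = tasksArray.Length;

        foreach (var task in tasksArray)
        {
            task.ContinueWith(t =>
            {
                try
                {
                    // Skip the predicate once a result (or fault) has been set.
                    if (!tcs.Task.IsCompleted && predicate(t))
                        tcs.TrySetResult(t);
                }
                catch (Exception ex)
                {
                    tcs.TrySetException(ex); // surface predicate failures to the caller
                }
                if (Interlocked.Decrement(ref count) == 0)
                    tcs.TrySetResult(null); // no-op if something was already set
            }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
        }
        return tcs.Task;
    }
}
```

    A call such as WhenFirst(tasks, t => t.Status == TaskStatus.RanToCompletion) mirrors the sample usage of the generic version.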

  • 2020-12-06 21:41

    Here is an attempted improvement of Eli Arbel's excellent answer. The improvements are:

    1. An exception thrown by the predicate is propagated as a fault of the returned task.
    2. The predicate is not called after a task has been accepted as the result.
    3. The predicate is executed in the original SynchronizationContext. This makes it possible to access UI elements (if the WhenFirst method is called from a UI thread).
    4. The source IEnumerable<Task<T>> is enumerated directly, without being converted to an array first.

    public static Task<Task<T>> WhenFirst<T>(IEnumerable<Task<T>> tasks,
        Func<Task<T>, bool> predicate)
    {
        if (tasks == null) throw new ArgumentNullException(nameof(tasks));
        if (predicate == null) throw new ArgumentNullException(nameof(predicate));
    
        var tcs = new TaskCompletionSource<Task<T>>(
            TaskCreationOptions.RunContinuationsAsynchronously);
        var pendingCount = 1; // The initial 1 represents the enumeration itself
        foreach (var task in tasks)
        {
            if (task == null) throw new ArgumentException($"The {nameof(tasks)}" +
                " argument included a null value.", nameof(tasks));
            Interlocked.Increment(ref pendingCount);
            HandleTaskCompletion(task);
        }
        if (Interlocked.Decrement(ref pendingCount) == 0) tcs.TrySetResult(null);
        return tcs.Task;
    
        async void HandleTaskCompletion(Task<T> task)
        {
            try
            {
                await task; // Continue on the captured context
            }
            catch { } // Ignore exception
    
            if (tcs.Task.IsCompleted) return;
    
            try
            {
                if (predicate(task))
                    tcs.TrySetResult(task);
                else
                    if (Interlocked.Decrement(ref pendingCount) == 0)
                        tcs.TrySetResult(null);
            }
            catch (Exception ex)
            {
                tcs.TrySetException(ex);
            }
        }
    }
    
  • 2020-12-06 21:48

    Another approach, very similar to Sir Rufo's answer but using IAsyncEnumerable and Ix.NET.

    Implement a small helper method that yields each task as soon as it completes:

    static IAsyncEnumerable<Task<T>> WhenCompleted<T>(IEnumerable<Task<T>> source) =>
        AsyncEnumerable.Create(_ =>
        {
            var tasks = source.ToList();
            Task<T> current = null;
            return AsyncEnumerator.Create(
                async () => tasks.Any() && tasks.Remove(current = await Task.WhenAny(tasks)), 
                () => current,
                async () => { });
        });
    

    One can then process the tasks in completion order, e.g. returning the first matching one as requested:

    var firstSuccessful = await WhenCompleted(tasks).FirstOrDefaultAsync(t => t.Status == TaskStatus.RanToCompletion);
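    If Ix.NET is not available, the same idea can be sketched with a C# 8 async iterator. This needs .NET Core 3.0 or later for IAsyncEnumerable<T>. It is an alternative to the Ix.NET helper above, not an exact equivalent. It calls Task.WhenAny repeatedly on the tasks that are still pending. CompletionOrder and FirstCompletedAsync are hypothetical names.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

public static class CompletionOrder
{
    // Yields each task as soon as it completes, using a C# 8 async iterator.
    public static async IAsyncEnumerable<Task<T>> WhenCompleted<T>(IEnumerable<Task<T>> source)
    {
        var pending = source.ToList();
        while (pending.Count > 0)
        {
            var finished = await Task.WhenAny(pending).ConfigureAwait(false);
            pending.Remove(finished);
            yield return finished;
        }
    }

    // First task, in completion order, that satisfies the predicate; null if none does.
    public static async Task<Task<T>> FirstCompletedAsync<T>(
        IEnumerable<Task<T>> source, Func<Task<T>, bool> predicate)
    {
        await foreach (var task in WhenCompleted(source))
        {
            if (predicate(task)) return task;
        }
        return null;
    }
}
```

    Returning from inside the await foreach disposes the iterator, so no further WhenAny calls are made once a match is found.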
    