What happened to AddOrUpdate in EF 7 / Core?

后端 未结 11 1258
暖寄归人
暖寄归人 2020-12-16 09:30

I'm writing a seed method using EntityFramework.Core 7.0.0-rc1-final.

What happened to the AddOrUpdate method of DbSet?

相关标签:
11条回答
  • 2020-12-16 10:07

    There is a third-party extension method, Upsert, that does this:

    // Create the role, then upsert it; On(...) names the column(s) used to detect an existing row.
    var employeeRole = new Role { Name = "Employee", NormalizedName = "employee" };
    context.Upsert(employeeRole)
           .On(r => new { r.Name })
           .Run();
    

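    It is available on GitHub (the Upsert(...).On(...).Run() API shown here is that of the FlexLabs.Upsert library).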

    0 讨论(0)
  • 2020-12-16 10:08

    The following MS Docs article, Disconnected entities, says that from EF Core 2.0 onwards, calling Update on its own acts as an AddOrUpdate, as long as the primary key column has an auto-generated value (e.g. an identity column).

    To quote from the article:

    If it is known whether or not an insert or update is needed, then either Add or Update can be used appropriately.

    However, if the entity uses auto-generated key values, then the Update method can be used for both cases.

    The Update method normally marks the entity for update, not insert. However, if the entity has an auto-generated key, and no key value has been set, then the entity is instead automatically marked for insert.

    This behavior was introduced in EF Core 2.0. For earlier releases it is always necessary to explicitly choose either Add or Update.

    If the entity is not using auto-generated keys, then the application must decide whether the entity should be inserted or updated.

    I've tried this out in a test project and can confirm that Update works for both adding and updating an entity in EF Core 2.2, with an auto-generated key.

    The Disconnected entities article linked above also includes sample code for a homemade InsertOrUpdate method, for earlier versions of EF Core or for entities without an auto-generated key. The sample code is specific to one entity class and would need to be modified to work generically.

    0 讨论(0)
  • 2020-12-16 10:10

    None of the other answers worked for me with Entity Framework Core 2.0, so here is the solution that did:

    public static class DbSetExtensions
    {
        /// <summary>
        /// Calls the single-entity <c>AddOrUpdate</c> overload for each of <paramref name="entities"/>.
        /// </summary>
        /// <param name="dbSet">Set to add to or update.</param>
        /// <param name="identifierExpression">Selects the value(s) used to find an existing row.</param>
        /// <param name="entities">Incoming entities.</param>
        /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
        public static void AddOrUpdate<T>(this DbSet<T> dbSet, Expression<Func<T, object>> identifierExpression, params T[] entities) where T : class
        {
            if (entities == null)
                throw new ArgumentNullException(nameof(entities));

            foreach (var entity in entities)
                AddOrUpdate(dbSet, identifierExpression, entity);
        }


        /// <summary>
        /// Looks for a stored row whose identifier (selected by <paramref name="identifierExpression"/>) equals
        /// that of <paramref name="entity"/>. If there is none, <paramref name="entity"/> is added. Otherwise the
        /// row's non-key, non-collection properties are overwritten with the incoming values and the row's key
        /// values are copied back onto <paramref name="entity"/>. Nothing is saved; call SaveChanges afterwards.
        /// </summary>
        /// <param name="dbSet">Set to add to or update.</param>
        /// <param name="identifierExpression">
        /// A member access such as <c>x => x.Name</c> / <c>x => x.Id</c>, or an anonymous type of member
        /// accesses such as <c>x => new { x.Name, x.Description }</c>.
        /// </param>
        /// <param name="entity">Incoming entity.</param>
        /// <exception cref="ArgumentNullException"><paramref name="identifierExpression"/> or <paramref name="entity"/> is null.</exception>
        /// <exception cref="NotSupportedException">The identifier contains something other than member accesses.</exception>
        public static void AddOrUpdate<T>(this DbSet<T> dbSet, Expression<Func<T, object>> identifierExpression, T entity) where T : class
        {
            if (identifierExpression == null)
                throw new ArgumentNullException(nameof(identifierExpression));
            if (entity == null)
                throw new ArgumentNullException(nameof(entity));

            var parameter = Expression.Parameter(typeof(T), "p");
            var lambda = Expression.Lambda<Func<T, bool>>(
                BuildIdentifierMatch(identifierExpression, parameter, entity),
                parameter);

            // Pass the expression tree (not a compiled delegate) so the lookup runs as a database query
            // instead of loading the whole table and filtering it in memory. Note that string comparison
            // then follows the database column's collation.
            var item = dbSet.FirstOrDefault(lambda);
            if (item == null)
            {
                // easy case
                dbSet.Add(entity);
                return;
            }

            var keyFields = GetKeyFields(typeof(T));

            // update all non key and non collection properties
            foreach (var p in typeof(T).GetProperties().Where(prop => prop.GetSetMethod() != null && prop.GetGetMethod() != null))
            {
                // ignore collections
                if (p.PropertyType != typeof(string) && p.PropertyType.GetInterface(nameof(System.Collections.IEnumerable)) != null)
                    continue;

                // ignore ID fields
                if (keyFields.Any(x => x.Name == p.Name))
                    continue;

                var incomingValue = p.GetValue(entity);
                if (!Equals(p.GetValue(item), incomingValue))
                {
                    p.SetValue(item, incomingValue);
                }
            }

            // also update key values on incoming data item where appropriate
            foreach (var idField in keyFields.Where(p => p.GetSetMethod() != null && p.GetGetMethod() != null))
            {
                var storedValue = idField.GetValue(item);
                if (!Equals(idField.GetValue(entity), storedValue))
                {
                    idField.SetValue(entity, storedValue);
                }
            }
        }


        // Builds "p.A == valueA && p.B == valueB ..." for the member(s) selected by identifierExpression,
        // taking the values from entity. Each member of an anonymous-type identifier is compared on its own:
        // anonymous types define no == operator, so Expression.Equal on the whole object would be a reference
        // comparison that never matches.
        private static Expression BuildIdentifierMatch<T>(Expression<Func<T, object>> identifierExpression, ParameterExpression parameter, T entity)
        {
            var body = StripConvert(identifierExpression.Body);
            var parts = body is NewExpression newExpression
                ? newExpression.Arguments
                : (IEnumerable<Expression>)new[] { body };

            Expression match = null;
            foreach (var rawPart in parts)
            {
                var part = StripConvert(rawPart);

                // Evaluate this member on the incoming entity using the selector's own parameter.
                var value = Expression.Lambda(part, identifierExpression.Parameters[0]).Compile().DynamicInvoke(entity);

                // Type the constant as the member type so null and nullable values still compare.
                var equality = Expression.Equal(
                    ReplaceParameter(part, parameter),
                    Expression.Constant(value, part.Type));

                match = match == null ? equality : Expression.AndAlso(match, equality);
            }

            return match ?? throw new NotSupportedException("Unknown expression type for AddOrUpdate: " + body.NodeType);
        }


        // Expression<Func<T, object>> boxes value-type members through a Convert node; look through it.
        private static Expression StripConvert(Expression expression)
        {
            while (expression.NodeType == ExpressionType.Convert || expression.NodeType == ExpressionType.ConvertChecked)
                expression = ((UnaryExpression)expression).Operand;
            return expression;
        }


        // Key properties: those marked [Key]; otherwise "Id" / "<TypeName>Id" by naming convention.
        private static List<PropertyInfo> GetKeyFields(Type dataType)
        {
            var keyFields = dataType.GetProperties().Where(p => p.GetCustomAttribute<KeyAttribute>() != null).ToList();
            if (keyFields.Any())
                return keyFields;

            string idName = dataType.Name + "Id";
            return dataType.GetProperties().Where(p =>
                string.Equals(p.Name, "Id", StringComparison.OrdinalIgnoreCase) ||
                string.Equals(p.Name, idName, StringComparison.OrdinalIgnoreCase)).ToList();
        }


        // Rebuilds a member-access chain (p.A, p.A.B, possibly with casts) on top of newParameter.
        private static Expression ReplaceParameter(Expression oldExpression, ParameterExpression newParameter)
        {
            switch (oldExpression.NodeType)
            {
                case ExpressionType.Parameter:
                    return newParameter;
                case ExpressionType.MemberAccess:
                    var m = (MemberExpression)oldExpression;
                    // Static members have no instance expression to rewrite.
                    return m.Expression == null
                        ? m
                        : Expression.MakeMemberAccess(ReplaceParameter(m.Expression, newParameter), m.Member);
                case ExpressionType.Convert:
                case ExpressionType.ConvertChecked:
                    var u = (UnaryExpression)oldExpression;
                    return Expression.MakeUnary(u.NodeType, ReplaceParameter(u.Operand, newParameter), u.Type);
                default:
                    throw new NotSupportedException("Unknown expression type for AddOrUpdate: " + oldExpression.NodeType);
            }
        }
    }
    

    You may need to update the ReplaceParameter() method if you have a more complicated identifierExpression. Simple property accessors work fine with this implementation, for example:

    // Match existing rows on a single property:
    context.Projects.AddOrUpdate(x => x.Name, new Project { ... })
    // Match on several properties at once by selecting an anonymous type:
    context.Projects.AddOrUpdate(x => new { x.Name, x.Description }, new Project { ... })
    

    Then call context.SaveChanges() to commit the data to the database.

    0 讨论(0)
  • 2020-12-16 10:10

    Here is my solution based on other solutions from this thread.

    • Supports composite keys
    • Supports shadow-property keys.
    • Stays within EF Core realm and doesn't use reflection.
    • Change _appDb to your context.
    
            /// <summary>
            /// Returns the primary-key value(s) of <paramref name="entity"/> read through <c>_appDb</c>:
            /// the single value for a one-column key, an <c>object[]</c> in key-property order for a
            /// composite key, or <c>null</c> when <typeparamref name="TEntity"/> is not in the model or has no key.
            /// </summary>
            public object PrimaryKeyValues<TEntity>(TEntity entity)
            {
                // Null-propagate both lookups so an unmapped or keyless type yields null (which the caller
                // reports with a descriptive error) instead of throwing a NullReferenceException here.
                var properties = _appDb.Model.FindEntityType(typeof(TEntity))?.FindPrimaryKey()?.Properties;
                if (properties == null)
                    return null;

                var entry = _appDb.Entry(entity);

                // Materialize once so the projection is not re-run for the count and again for the result.
                var values = properties.Select(p => entry.Property(p.Name).CurrentValue).ToArray();

                return values.Length == 1 ? values[0] : values;
            }
    
    
            /// <summary>
            /// Adds <paramref name="entity"/> when no row with the same primary key is found; otherwise copies
            /// its values onto the row returned by FindAsync and marks that row for update.
            /// Nothing is saved here; call SaveChanges/SaveChangesAsync afterwards.
            /// </summary>
            /// <returns>The stored row when updating, otherwise <paramref name="entity"/> itself.</returns>
            /// <exception cref="InvalidOperationException"><typeparamref name="TEntity"/> has no primary key.</exception>
            public async Task<TEntity> AddOrUpdateAsync<TEntity>(TEntity entity) where TEntity : class
            {
                var pkValue = PrimaryKeyValues(entity);

                if (pkValue == null)
                {
                    throw new InvalidOperationException($"{typeof(TEntity).FullName} does not have a primary key specified. Unable to exec AddOrUpdateAsync call.");
                }

                // FindAsync takes the key values as a params array. pkValue is statically typed object, so
                // passing it directly would wrap a composite key's object[] inside another one-element array
                // and FindAsync would reject it. Hand the key values over as the array itself.
                var keyValues = pkValue as object[] ?? new[] { pkValue };

                if ((await _appDb.FindAsync(typeof(TEntity), keyValues)) is TEntity dbEntry)
                {
                    _appDb.Entry(dbEntry).CurrentValues.SetValues(entity);
                    _appDb.Update(dbEntry);

                    entity = dbEntry;
                }
                else
                {
                    _appDb.Add(entity);
                }

                return entity;
            }
    

    Update - Add Or Update Range

    Complete solution. It doesn't support keys that are shadow properties.

    DbContextExtensions.cs

            // FIND ALL
            // ===============================================================
            /// <summary>
            /// Loads, in a single database query, the stored entities whose primary keys match those of
            /// <paramref name="args"/>. Returns all, some or none of them depending on what exists.
            /// Keys that are shadow properties are not supported.
            /// </summary>
            /// <typeparam name="TEntity">Entity type; must have a primary key in the model.</typeparam>
            /// <param name="dbContext">Context to query.</param>
            /// <param name="args">Entities whose key values are looked up.</param>
            /// <returns>The matching entities as loaded by <paramref name="dbContext"/>.</returns>
            /// <exception cref="ArgumentException"><typeparamref name="TEntity"/> has no primary key.</exception>
            /// <exception cref="ArgumentNullException"><paramref name="args"/> is null.</exception>
            public static async Task<TEntity[]> FindAllAsync<TEntity>(this DbContext dbContext, IEnumerable<TEntity> args) where TEntity : class
            {
                var properties = dbContext.Model.FindEntityType(typeof(TEntity))?.FindPrimaryKey()?.Properties;

                if (properties == null)
                    throw new ArgumentException($"{typeof(TEntity).FullName} does not have a primary key specified.");

                if (args == null)
                    throw new ArgumentNullException(nameof(args), "Entities to find argument cannot be null");

                // Enumerate the caller's sequence exactly once.
                var entities = args.ToArray();
                if (entities.Length == 0)
                    return Array.Empty<TEntity>();

                var dbParameter = Expression.Parameter(typeof(TEntity), typeof(TEntity).Name);

                // (x.Pk1 == a1 && x.Pk2 == a2) || (x.Pk1 == b1 && x.Pk2 == b2) || ...
                var predicate = entities.Select(entity =>
                {
                    var entry = dbContext.Entry(entity);

                    return properties.Select(p =>
                    {
                        var left = Expression.Property(dbParameter, typeof(TEntity).GetProperty(p.Name));

                        // Type the constant as the property type so nullable key values compare correctly.
                        var right = Expression.Constant(entry.Property(p.Name).CurrentValue, left.Type);

                        return Expression.Equal(left, right);
                    })
                    .Aggregate((acc, next) => Expression.AndAlso(acc, next));
                })
                .Aggregate((acc, next) => Expression.OrElse(acc, next));

                var whereLambda = Expression.Lambda<Func<TEntity, bool>>(predicate, dbParameter);

                // Queryable.Where with an expression tree lets EF translate the filter to SQL, so only the
                // matching rows are loaded (and tracked) rather than the entire table.
                return await dbContext.Set<TEntity>().Where(whereLambda).ToArrayAsync();
            }
    
            // ADD OR UPDATE - RANGE - ASYNC
            // ===============================================================
            /// <summary>
            /// For each entity in a range, marks it for insert when no stored row has its primary key,
            /// otherwise marks it for update. Nothing is saved; call SaveChanges afterwards.
            /// </summary>
            /// <typeparam name="TEntity">Entity type; must have a primary key in the model.</typeparam>
            /// <param name="dbContext">Context that will track the entities.</param>
            /// <param name="entities">Entities to add or update.</param>
            /// <returns>How many entities were marked as added and how many as updated.</returns>
            /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
            public static async Task<(int AddedCount, int UpdatedCount)> AddOrUpdateRangeAsync<TEntity>(this DbContext dbContext, IEnumerable<TEntity> entities) where TEntity : class
            {
                if (entities == null)
                    throw new ArgumentNullException(nameof(entities));

                // Enumerate the caller's sequence once; it is read by the lookup and by the partition below.
                var candidates = entities.ToList();

                // Build the set of stored keys eagerly. A lazy Select here would re-run the detach and the
                // MD5 hashing every time the partition predicate is evaluated.
                var existingHashes = new HashSet<string>();
                foreach (var stored in await dbContext.FindAllAsync(candidates))
                {
                    // Detach the loaded copy so the incoming instance with the same key can be attached by UpdateRange.
                    dbContext.Entry(stored).State = EntityState.Detached;
                    existingHashes.Add(dbContext.PrimaryKeyHash(stored));
                }

                // Partition once, before anything is attached, so the counts reflect this snapshot.
                var toUpdate = new List<TEntity>();
                var toAdd = new List<TEntity>();
                foreach (var entity in candidates)
                {
                    if (existingHashes.Contains(dbContext.PrimaryKeyHash(entity)))
                        toUpdate.Add(entity);
                    else
                        toAdd.Add(entity);
                }

                dbContext.UpdateRange(toUpdate);
                dbContext.AddRange(toAdd);

                return (AddedCount: toAdd.Count, UpdatedCount: toUpdate.Count);
            }
    
    
            // ADD OR UPDATE - ASYNC
            // ===============================================================
            /// <summary>
            /// Single-entity convenience wrapper over AddOrUpdateRangeAsync: the entity is marked for insert
            /// when its primary key is not stored yet, otherwise for update. The added/updated counts are discarded.
            /// </summary>
            /// <typeparam name="TEntity">Entity type.</typeparam>
            /// <param name="dbContext">Context that will track the entity.</param>
            /// <param name="entity">Entity to add or update.</param>
            /// <returns>A task that completes once the entity has been added or attached.</returns>
            public static async Task AddOrUpdateAsync<TEntity>(this DbContext dbContext, TEntity entity) where TEntity : class
            {
                TEntity[] batch = { entity };
                await dbContext.AddOrUpdateRangeAsync(batch);
            }
    
            // PK HASH
            // ===============================================================
            /// <summary>
            /// Builds a string that identifies the entity's primary key: Crypto.HashGUID of each key value,
            /// concatenated in key-property order.
            /// </summary>
            /// <typeparam name="TTarget">Entity type; must have a primary key in the model.</typeparam>
            /// <param name="dbContext">Context whose model describes the key.</param>
            /// <param name="entity">Entity whose key values are hashed.</param>
            /// <returns>The concatenated per-property hashes.</returns>
            public static string PrimaryKeyHash<TTarget>(this DbContext dbContext, TTarget entity)
            {
                var keyProperties = dbContext.Model.FindEntityType(typeof(TTarget)).FindPrimaryKey().Properties;
                var entityEntry = dbContext.Entry(entity);

                var perKeyHashes = keyProperties.Select(keyProperty => Crypto.HashGUID(entityEntry.Property(keyProperty.Name).CurrentValue));
                return string.Concat(perKeyHashes);
            }
    

    Crypto.cs

        public class Crypto
        {
            /// <summary>
            /// Returns the MD5 digest of <c>obj.ToString()</c> (UTF-8 encoded) as 32 lowercase hexadecimal
            /// characters, i.e. a GUID-length string. MD5 is not collision-resistant; do not use this for security.
            /// </summary>
            /// <param name="obj">Value to fingerprint; its ToString() result is what gets hashed.</param>
            /// <returns>A 32-character lowercase hexadecimal string.</returns>
            /// <exception cref="ArgumentNullException"><paramref name="obj"/> is null.</exception>
            public static string HashGUID(object obj)
            {
                if (obj == null)
                    throw new ArgumentNullException(nameof(obj));

                byte[] bytes = Encoding.UTF8.GetBytes(obj.ToString());

                // MD5.Create() replaces the obsolete MD5CryptoServiceProvider; "using" disposes it even if hashing throws.
                byte[] digest;
                using (var md5 = MD5.Create())
                {
                    digest = md5.ComputeHash(bytes);
                }

                // 16 digest bytes -> 32 hex characters, so no extra padding is needed.
                var text = new StringBuilder(digest.Length * 2);
                foreach (byte b in digest)
                {
                    text.Append(b.ToString("x2"));
                }

                return text.ToString();
            }
        }
    

    IEnumerableExtensions.cs

            /// <summary>
            /// Splits <paramref name="source"/> into the elements that satisfy <paramref name="predicate"/> and
            /// those that do not, preserving order. The source is enumerated once and the predicate is evaluated
            /// exactly once per element; both halves are materialized lists.
            /// </summary>
            /// <typeparam name="T">Element type.</typeparam>
            /// <param name="source">Sequence to split.</param>
            /// <param name="predicate">Test that decides which half an element goes to.</param>
            /// <returns>(True: the matching elements, False: the rest).</returns>
            /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="predicate"/> is null.</exception>
            public static (IEnumerable<T> True, IEnumerable<T> False) DivideOn<T>(this IEnumerable<T> source, Func<T, bool> predicate)
            {
                if (source == null)
                    throw new ArgumentNullException(nameof(source));
                if (predicate == null)
                    throw new ArgumentNullException(nameof(predicate));

                var matching = new List<T>();
                var rest = new List<T>();
                foreach (var item in source)
                {
                    if (predicate(item))
                        matching.Add(item);
                    else
                        rest.Add(item);
                }

                return (matching, rest);
            }
    

    Leave a comment if you use it (ノ◕ヮ◕)ノ✲゚。⋆

    0 讨论(0)
  • 2020-12-16 10:12

    I think this is what you want.

    public static class DbSetExtension
    {
        /// <summary>
        /// Adds <paramref name="data"/>, or, when a stored row has the same primary key (per the EF model),
        /// overwrites that row's values with <paramref name="data"/> and marks it Modified.
        /// Nothing is saved; call SaveChanges afterwards.
        /// </summary>
        /// <exception cref="InvalidOperationException">No public property of <typeparamref name="T"/> is part of the primary key.</exception>
        public static void AddOrUpdate<T>(this DbSet<T> dbSet, T data) where T : class
        {
            var context = dbSet.GetContext();
            var ids = context.Model.FindEntityType(typeof(T)).FindPrimaryKey().Properties.Select(x => x.Name).ToList();

            var t = typeof(T);
            var keyFields = t.GetProperties().Where(p => ids.Contains(p.Name)).ToList();
            if (keyFields.Count <= 0)
            {
                throw new InvalidOperationException($"{t.FullName} does not have a KeyAttribute field. Unable to exec AddOrUpdate call.");
            }

            // Let the database find the row instead of loading the whole table and filtering it in memory.
            var dbVal = dbSet.AsNoTracking().FirstOrDefault(BuildKeyMatch(keyFields, data));
            if (dbVal != null)
            {
                context.Entry(dbVal).CurrentValues.SetValues(data);
                context.Entry(dbVal).State = EntityState.Modified;
                return;
            }
            dbSet.Add(data);
        }

        /// <summary>
        /// Adds <paramref name="data"/>, or, when a stored row matches on the properties selected by
        /// <paramref name="key"/>, copies that row's primary-key values onto <paramref name="data"/>, overwrites
        /// the row's values with <paramref name="data"/> and marks it Modified. Nothing is saved.
        /// </summary>
        /// <param name="key">
        /// Property (or properties) to match on, e.g. <c>x => x.Code</c> or <c>x => new { x.Code, x.Region }</c>.
        /// </param>
        /// <exception cref="ArgumentException"><paramref name="key"/> selects no properties.</exception>
        /// <exception cref="NotSupportedException"><paramref name="key"/> contains something other than direct property accesses.</exception>
        public static void AddOrUpdate<T>(this DbSet<T> dbSet, Expression<Func<T, object>> key, T data) where T : class
        {
            var context = dbSet.GetContext();
            var ids = context.Model.FindEntityType(typeof(T)).FindPrimaryKey().Properties.Select(x => x.Name).ToList();

            var keyFields = GetSelectedProperties(key);
            if (keyFields.Count == 0)
            {
                // Matching on nothing would silently pick an arbitrary stored row.
                throw new ArgumentException($"The key selector for {typeof(T).FullName} does not select any property.", nameof(key));
            }

            var dbVal = dbSet.AsNoTracking().FirstOrDefault(BuildKeyMatch(keyFields, data));
            if (dbVal != null)
            {
                var keyAttrs = data.GetType().GetProperties().Where(p => ids.Contains(p.Name)).ToList();
                if (keyAttrs.Any())
                {
                    // Give the incoming object the stored row's primary key so SetValues below does not
                    // overwrite the row's key with the incoming (possibly default) key values.
                    foreach (var keyAttr in keyAttrs)
                    {
                        keyAttr.SetValue(data,
                            dbVal.GetType()
                                .GetProperties()
                                .FirstOrDefault(p => p.Name == keyAttr.Name)
                                .GetValue(dbVal));
                    }
                    context.Entry(dbVal).CurrentValues.SetValues(data);
                    context.Entry(dbVal).State = EntityState.Modified;
                    return;
                }
            }
            dbSet.Add(data);
        }

        // Resolves the properties read by a selector such as x => x.Code or x => new { x.Code, x.Region }.
        // Reading them from the expression (rather than from the key object's runtime type) also works for
        // single-property selectors like x => x.Name, whose key object is just a string.
        private static List<PropertyInfo> GetSelectedProperties<T>(Expression<Func<T, object>> key)
        {
            var body = key.Body;

            // Value-type members are boxed to object through a Convert node.
            if (body.NodeType == ExpressionType.Convert || body.NodeType == ExpressionType.ConvertChecked)
                body = ((UnaryExpression)body).Operand;

            var parts = body is NewExpression newExpression
                ? newExpression.Arguments
                : (IEnumerable<Expression>)new[] { body };

            var properties = new List<PropertyInfo>();
            foreach (var part in parts)
            {
                if (part is MemberExpression member && member.Member is PropertyInfo property && member.Expression == key.Parameters[0])
                    properties.Add(property);
                else
                    throw new NotSupportedException($"Unsupported key selector part for AddOrUpdate: {part}");
            }
            return properties;
        }

        // Builds x => x.P1 == data.P1 && x.P2 == data.P2 ... as an expression tree EF can translate to SQL.
        // Constants are typed as the property type so null / nullable values compare instead of throwing.
        private static Expression<Func<T, bool>> BuildKeyMatch<T>(IEnumerable<PropertyInfo> keyFields, T data)
        {
            var parameter = Expression.Parameter(typeof(T), "x");
            var body = keyFields
                .Select(field => (Expression)Expression.Equal(
                    Expression.Property(parameter, field),
                    Expression.Constant(field.GetValue(data), field.PropertyType)))
                .Aggregate((acc, next) => Expression.AndAlso(acc, next));
            return Expression.Lambda<Func<T, bool>>(body, parameter);
        }
    }
    
    public static class HackyDbSetGetContextTrick
    {
        /// <summary>
        /// Returns the DbContext behind <paramref name="dbSet"/> by reading a private <c>_context</c> field of
        /// the runtime DbSet implementation. This relies on non-public EF internals and may break on upgrade.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="dbSet"/> is null.</exception>
        /// <exception cref="InvalidOperationException">No <c>_context</c> field holding a DbContext was found.</exception>
        public static DbContext GetContext<TEntity>(this DbSet<TEntity> dbSet)
            where TEntity : class
        {
            if (dbSet == null)
                throw new ArgumentNullException(nameof(dbSet));

            // Private fields declared on a base class are not returned for the derived type, so walk the hierarchy.
            for (var type = dbSet.GetType(); type != null; type = type.GetTypeInfo().BaseType)
            {
                var field = type.GetTypeInfo().GetField("_context", BindingFlags.NonPublic | BindingFlags.Instance);
                if (field != null && field.GetValue(dbSet) is DbContext context)
                    return context;
            }

            // NOTE(review): EF Core also exposes the owning context publicly through ICurrentDbContext
            // (dbSet.GetService<ICurrentDbContext>().Context); consider switching to that.
            throw new InvalidOperationException(
                $"Could not find the DbContext behind {dbSet.GetType().FullName}; the EF Core internals this relies on may have changed.");
        }
    }
    
    0 讨论(0)
提交回复
热议问题