Using DataTable in .NET Core

前端 未结 6 895
你的背包
你的背包 2020-12-14 10:11

I have a stored procedure in SQL Server that accepts a User-Defined Table Type. I'm following the answer from this post: Bulk insert from C# list into SQL Server into multip… [question text truncated]

6条回答
  •  攒了一身酷
    2020-12-14 11:00

    You can use a DbDataReader as the value of the SQL parameter, so the idea is to wrap an IEnumerable<T> in a DbDataReader.

    /// <summary>
    /// Exposes a sequence of <typeparamref name="T"/> objects as a forward-only
    /// <see cref="DbDataReader"/>, so it can be passed as the value of a table-valued
    /// (<c>SqlDbType.Structured</c>) SQL parameter without building a DataTable.
    /// Each public, readable, non-indexed instance property of <typeparamref name="T"/>
    /// becomes one column, in the order returned by <see cref="Type.GetProperties()"/>.
    /// </summary>
    /// <typeparam name="T">The row type.</typeparam>
    public class ObjectDataReader<T> : DbDataReader
    {
        // True when this reader created the enumerator itself and must dispose it.
        private readonly bool _iteratorOwned;
        private readonly IEnumerator<T> _iterator;
        private readonly Dictionary<string, int> _propertyNameToOrdinal = new Dictionary<string, int>();
        private readonly Dictionary<int, string> _ordinalToPropertyName = new Dictionary<int, string>();
        // Compiled getters indexed by ordinal, each equivalent to x => (object)x.Property.
        private Func<T, object>[] _getPropertyValueFuncs;
        // Declared column types indexed by ordinal, with Nullable<X> unwrapped to X.
        private Type[] _fieldTypes;
        private bool _isClosed;

        /// <summary>
        /// Creates a reader over <paramref name="enumerable"/>. The reader owns the enumerator
        /// it obtains and disposes it when the reader is closed or disposed.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="enumerable"/> is null.</exception>
        public ObjectDataReader(IEnumerable<T> enumerable)
        {
            if (enumerable == null) throw new ArgumentNullException(nameof(enumerable));

            _iteratorOwned = true;
            // Do not prime the enumerator with MoveNext() here: Read() performs the first
            // advance, so priming would silently drop the first element.
            _iterator = enumerable.GetEnumerator();
            Initialize();
        }

        /// <summary>
        /// Creates a reader over an existing enumerator. The caller keeps ownership of
        /// <paramref name="iterator"/>; this reader never disposes it.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="iterator"/> is null.</exception>
        public ObjectDataReader(IEnumerator<T> iterator)
        {
            if (iterator == null) throw new ArgumentNullException(nameof(iterator));

            _iterator = iterator;
            Initialize();
        }

        /// <summary>Closes the reader (see <see cref="Close"/>) when disposing.</summary>
        protected override void Dispose(bool disposing)
        {
            if (disposing)
                Close();

            base.Dispose(disposing);
        }

        /// <summary>
        /// Marks the reader closed and, if the reader owns the enumerator, disposes it.
        /// Safe to call more than once.
        /// </summary>
        public override void Close()
        {
            if (_isClosed)
                return;

            _isClosed = true;
            if (_iteratorOwned)
                _iterator.Dispose();
        }

        /// <summary>
        /// Builds the column metadata and compiles one getter per column.
        /// Static, indexed, and write-only properties are skipped, because they cannot be
        /// read from a row instance.
        /// </summary>
        private void Initialize()
        {
            var properties = Array.FindAll(
                typeof(T).GetProperties(),
                p => p.GetIndexParameters().Length == 0
                     && p.GetGetMethod() != null
                     && !p.GetGetMethod().IsStatic);

            _getPropertyValueFuncs = new Func<T, object>[properties.Length];
            _fieldTypes = new Type[properties.Length];
            var parameterExpression = Expression.Parameter(typeof(T), "x");

            for (int ordinal = 0; ordinal < properties.Length; ordinal++)
            {
                var property = properties[ordinal];
                _propertyNameToOrdinal.Add(property.Name, ordinal);
                _ordinalToPropertyName.Add(ordinal, property.Name);
                _fieldTypes[ordinal] = Nullable.GetUnderlyingType(property.PropertyType) ?? property.PropertyType;

                // Compile x => (object)x.Property once, so per-row reads avoid reflection.
                var body = Expression.Convert(Expression.Property(parameterExpression, property), typeof(object));
                _getPropertyValueFuncs[ordinal] = Expression.Lambda<Func<T, object>>(body, parameterExpression).Compile();
            }
        }

        public override object this[int ordinal] => GetValue(ordinal);

        public override object this[string name] => GetValue(GetOrdinal(name));

        /// <summary>Always 0: this reader exposes a single, non-nested row set.</summary>
        public override int Depth => 0;

        public override int FieldCount => _ordinalToPropertyName.Count;

        /// <summary>
        /// Always true. The source is a lazy sequence, so emptiness cannot be known without
        /// consuming it. Callers must rely on <see cref="Read"/> returning false.
        /// </summary>
        public override bool HasRows => true;

        /// <summary>True once <see cref="Close"/> or Dispose has been called.</summary>
        public override bool IsClosed => _isClosed;

        /// <summary>Always -1: the reader does not execute any statement.</summary>
        public override int RecordsAffected => -1;

        public override bool GetBoolean(int ordinal) => (bool)GetValue(ordinal);

        public override byte GetByte(int ordinal) => (byte)GetValue(ordinal);

        /// <summary>
        /// Copies bytes from a <c>byte[]</c> column. When <paramref name="buffer"/> is null,
        /// returns the total length. Assumes the value is non-null; check IsDBNull first.
        /// </summary>
        public override long GetBytes(int ordinal, long dataOffset, byte[] buffer, int bufferOffset, int length)
        {
            return CopyArray((byte[])GetValue(ordinal), dataOffset, buffer, bufferOffset, length);
        }

        public override char GetChar(int ordinal) => (char)GetValue(ordinal);

        /// <summary>
        /// Copies characters from a <c>string</c> or <c>char[]</c> column. When
        /// <paramref name="buffer"/> is null, returns the total length. Assumes the value is
        /// non-null; check IsDBNull first.
        /// </summary>
        public override long GetChars(int ordinal, long dataOffset, char[] buffer, int bufferOffset, int length)
        {
            object value = GetValue(ordinal);
            var text = value as string;
            char[] chars = text != null ? text.ToCharArray() : (char[])value;
            return CopyArray(chars, dataOffset, buffer, bufferOffset, length);
        }

        /// <summary>Returns the CLR name of the column's declared type.</summary>
        public override string GetDataTypeName(int ordinal) => GetFieldType(ordinal).Name;

        public override DateTime GetDateTime(int ordinal) => (DateTime)GetValue(ordinal);

        public override decimal GetDecimal(int ordinal) => (decimal)GetValue(ordinal);

        public override double GetDouble(int ordinal) => (double)GetValue(ordinal);

        /// <summary>Enumerates the remaining rows as data records.</summary>
        public override IEnumerator GetEnumerator() => new DbEnumerator(this);

        /// <summary>
        /// Returns the declared property type, with Nullable&lt;X&gt; unwrapped to X.
        /// Unlike the row values, this is valid before the first <see cref="Read"/>.
        /// </summary>
        public override Type GetFieldType(int ordinal) => _fieldTypes[ordinal];

        public override float GetFloat(int ordinal) => (float)GetValue(ordinal);

        public override Guid GetGuid(int ordinal) => (Guid)GetValue(ordinal);

        public override short GetInt16(int ordinal) => (short)GetValue(ordinal);

        public override int GetInt32(int ordinal) => (int)GetValue(ordinal);

        public override long GetInt64(int ordinal) => (long)GetValue(ordinal);

        /// <summary>Returns the property name for <paramref name="ordinal"/>, or null if out of range.</summary>
        public override string GetName(int ordinal)
        {
            string name;
            return _ordinalToPropertyName.TryGetValue(ordinal, out name) ? name : null;
        }

        /// <summary>
        /// Returns the ordinal of the property named <paramref name="name"/> (case-sensitive),
        /// or -1 if there is none. Note that the DbDataReader contract would throw instead;
        /// -1 is kept for compatibility.
        /// </summary>
        public override int GetOrdinal(string name)
        {
            int ordinal;
            return _propertyNameToOrdinal.TryGetValue(name, out ordinal) ? ordinal : -1;
        }

        public override string GetString(int ordinal) => (string)GetValue(ordinal);

        /// <summary>
        /// Reads the column from the current row. Returns null, not DBNull, for null
        /// properties. Only meaningful after <see cref="Read"/> has returned true.
        /// </summary>
        public override object GetValue(int ordinal)
        {
            return _getPropertyValueFuncs[ordinal](_iterator.Current);
        }

        /// <summary>
        /// Fills <paramref name="values"/> with the current row, mapping null to
        /// <see cref="DBNull.Value"/>. Returns the number of values written.
        /// </summary>
        public override int GetValues(object[] values)
        {
            int count = Math.Min(values.Length, FieldCount);
            for (int i = 0; i < count; i++)
            {
                values[i] = GetValue(i) ?? DBNull.Value;
            }

            return count;
        }

        public override bool IsDBNull(int ordinal) => GetValue(ordinal) == null;

        /// <summary>Always false: there is only one result set.</summary>
        public override bool NextResult() => false;

        /// <summary>Advances to the next object in the sequence.</summary>
        /// <exception cref="InvalidOperationException">The reader is closed.</exception>
        public override bool Read()
        {
            if (_isClosed)
                throw new InvalidOperationException("Invalid attempt to call Read when the reader is closed.");

            return _iterator.MoveNext();
        }

        // Implements the GetBytes/GetChars contract. A null buffer returns the source length.
        // Otherwise copies up to `length` elements starting at `dataOffset` and returns the
        // number copied, which is 0 when the offset is past the end.
        private static long CopyArray<TElement>(TElement[] source, long dataOffset, TElement[] buffer, int bufferOffset, int length)
        {
            if (buffer == null)
                return source.Length;
            if (dataOffset >= source.Length)
                return 0;

            int count = (int)Math.Min(length, source.Length - dataOffset);
            Array.Copy(source, dataOffset, buffer, bufferOffset, count);
            return count;
        }
    }
    

    Then, you can use this class:

    /// <summary>
    /// Demo: sends 100 generated rows to the procMergePageView stored procedure as a
    /// table-valued parameter, streamed through ObjectDataReader&lt;PageViewTableType&gt;.
    /// </summary>
    /// <param name="args">Optional; when present, args[0] replaces the default connection string.</param>
    static void Main(string[] args)
    {
        Console.WriteLine("Hello World!");
        string connectionString = args != null && args.Length > 0
            ? args[0]
            : "Server=(local);Database=Sample;Trusted_Connection=True;";

        using (var connection = new SqlConnection(connectionString))
        {
            connection.Open();

            using (var command = connection.CreateCommand())
            {
                command.CommandType = System.Data.CommandType.StoredProcedure;
                command.CommandText = "procMergePageView";

                var p1 = command.CreateParameter();
                command.Parameters.Add(p1);
                p1.ParameterName = "@Display";
                // Structured marks the parameter as table-valued; the reader supplies its rows.
                p1.SqlDbType = System.Data.SqlDbType.Structured;
                var items = PageViewTableType.Generate(100);
                // The type argument is required: ObjectDataReader is generic over the row type.
                using (DbDataReader dr = new ObjectDataReader<PageViewTableType>(items))
                {
                    p1.Value = dr;
                    command.ExecuteNonQuery();
                }
            }
        }
    }
    
    /// <summary>
    /// Row type streamed into the dbo.PageViewTableType table-valued parameter.
    /// </summary>
    class PageViewTableType
    {
        // Must match the name of the column of the TVP
        public long PageViewID { get; set; }

        /// <summary>
        /// Lazily yields <paramref name="count"/> dummy rows with PageViewID 0 .. count-1.
        /// Yields nothing when <paramref name="count"/> is 0 or negative. The return type is
        /// generic, so the result can feed ObjectDataReader&lt;PageViewTableType&gt; directly.
        /// </summary>
        public static IEnumerable<PageViewTableType> Generate(int count)
        {
            for (int i = 0; i < count; i++)
            {
                yield return new PageViewTableType { PageViewID = i };
            }
        }
    }
    

    The SQL scripts:

    -- Per-page view counter; dbo.procMergePageView inserts new IDs and increments PageViewCount.
    CREATE TABLE [dbo].[PageView]
    (
        [PageViewID]    BIGINT NOT NULL,
        [PageViewCount] BIGINT NOT NULL,
        CONSTRAINT [pkPageView] PRIMARY KEY CLUSTERED ([PageViewID])
    );
    GO
    
    -- Table type for the @Display table-valued parameter of dbo.procMergePageView.
    CREATE TYPE [dbo].[PageViewTableType] AS TABLE
    (
        [PageViewID] BIGINT NOT NULL
    );
    GO
    
    -- Records one view per row of @Display: existing PageViewIDs are incremented,
    -- new ones are inserted with their view count.
    CREATE PROCEDURE dbo.procMergePageView
        @Display dbo.PageViewTableType READONLY
    AS
    BEGIN
        -- HOLDLOCK keeps the match check and the insert atomic, so two concurrent calls
        -- cannot both take the NOT MATCHED branch for the same new PageViewID.
        MERGE INTO dbo.PageView WITH (HOLDLOCK) AS T
        -- Collapse duplicate IDs first. MERGE fails when one target row is matched by
        -- several source rows, and duplicate new IDs would violate pkPageView on insert.
        USING (
            SELECT PageViewID, COUNT(*) AS ViewCount
            FROM @Display
            GROUP BY PageViewID
        ) AS S
        ON T.PageViewID = S.PageViewID
        WHEN MATCHED THEN UPDATE SET T.PageViewCount = T.PageViewCount + S.ViewCount
        WHEN NOT MATCHED THEN INSERT (PageViewID, PageViewCount) VALUES (S.PageViewID, S.ViewCount);
    END
    

    By the way, I've written a blog post about the ObjectDataReader

提交回复
热议问题