简单线性回归(sklearn + tensorflow)

放肆的年华 提交于 2020-03-03 23:14:44

概述

最近学习机器学习(和深度学习),入门第一个接触的便是简单线性回归。
所谓线性回归,是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。其形式可表示为:y = w1x1 + w2x2 + w3x3 + ... + w_nx_n + b
而简单线性回归,是其最简单的形式:y = wx + b,即我们所熟知的一次函数,理解为给定权重w和偏置(或称为截距)b,结果y随变量x的变化而变化。

简单线性回归

机器学习中的简单线性回归,个人理解为给定一系列的x值和对应的y值,来确定权重w和偏置b的合理值(即根据某种方法,从数据中找规律)。
具体步骤可大致分为:

  1. 给定一个w和b的初始值
  2. 定义线性回归函数(即一次函数):y = wx + b
  3. 定义损失函数:n条数据,对每一次得到的y值和实际已知的y_true值相减后求平方,然后求和,再求平均值(所谓最小二乘法),用式子可表示为:loss = ( ∑( y - y_true )² ) / n
  4. 训练求损失最小值(即:loss最小时,预测的结果与真实值最接近):入门先不考虑太多,直接使用梯度下降法(一种可以自动更改w和b,使loss函数结果最小的方法)

示例一

以给定的每天来咖啡店人数(个)和咖啡店销售额(元)数据,来实现机器学习中的简单线性回归。

准备数据,并使用散点图显示

简单线性回归(sklearn + tensorflow)

使用sklearn训练简单线性回归模型并进行预测

简单线性回归(sklearn + tensorflow)

示例二

使用tensorflow训练简单线性回归模型

import tensorflow as tf
LEARNING_RATE = 0.1

#创建100条身高与体重数据,假设身高与体重真实关系为:y_real=0.7x + 0.8
with tf.variable_scope( 'Input_Data' ):
    x_data = tf.random_normal( [100, 1], mean=1.75, stddev=0.5, name='x_data' )
    y_real = tf.matmul( x_data, [[0.7]] ) + 0.8

#创建模型:定义权重、偏置和预测结果表达式
with tf.variable_scope( 'Model' ):
    weight = tf.Variable( tf.random_normal( [1, 1], mean=0.0, stddev=0.1 ), name='weight' )
    bias = tf.Variable( 0.0, name='bias' )
    y_predict = tf.matmul( x_data, weight ) + bias

#定义损失函数和优化器(使用梯度下降法)
with tf.variable_scope( 'Optimizer' ):
    loss = tf.reduce_mean( tf.square( y_predict - y_real ) )
    train_op = tf.train.GradientDescentOptimizer( LEARNING_RATE ).minimize( loss )

#创建tensorflow会话,训练模型
with tf.Session() as sess:
    #初始化变量
    init_op = tf.global_variables_initializer()
    sess.run( init_op )

    #写events文件,可使用tensorboard查看
    tf.summary.FileWriter( 'D:/Mine.py/', sess.graph ) 

    #进行训练
    for step in range( 1, 1001 ):
        sess.run( train_op )
        if step % 100 == 0:
            print( 'after{}step, weight={}, bias={}, loss={}'
                  .format( step, sess.run( weight ), sess.run( bias )
                  , sess.run( loss ) ) )

训练结果如下:
简单线性回归(sklearn + tensorflow)

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!