只改了版本问题和我上个博客写的msr_error()函数,其余代码来自于下边这个博客:https://blog.csdn.net/baixiaozhe/article/details/54410313,增加了tf.reset_default_graph()。
上边的链接增加了波士顿房价数据的读取和预处理, 比莫烦tensorflow LSTM/RNN的例子更全面,更有价值。代码亲测,可运行。
from sklearn.datasets import load_boston
from sklearn import preprocessing
import tensorflow as tf
import numpy as np
tf.reset_default_graph()
# 波士顿房价数据
boston = load_boston()
x = boston.data
y = boston.target
print('波士顿数据X:',x.shape)# (506, 13)
# print(x[::100])
print('波士顿房价Y:',y.shape)
# print(y[::100])
# 数据标准化
ss_x = preprocessing.StandardScaler()
train_x = ss_x.fit_transform(x)
ss_y = preprocessing.StandardScaler()
train_y = ss_y.fit_transform(y.reshape(-1, 1))
BATCH_SIZE = 30
def get_batch_boston():
def get_batch():
class LSTMRNN(object):
if __name__ == '__main__':