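Keras 中最常用的三个回调函数（callback）是：
- TensorBoard：记录训练日志，供 TensorBoard 可视化；
- ModelCheckpoint：训练过程中保存模型（设置 `save_best_only=True` 时只保留验证集上表现最好的模型）；
- EarlyStopping：当监控指标在若干个 epoch 内不再改善时提前终止训练。
使用方法如下（把回调列表传给 `model.fit` 的 `callbacks` 参数即可）：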
# Directory that receives both the TensorBoard logs and the checkpoint file.
logdir = "./callback"
# exist_ok=True replaces the original exists()/mkdir() pair, whose body had lost
# its indentation, and also avoids the check-then-create race.
os.makedirs(logdir, exist_ok=True)
out_put_model_file = os.path.join(logdir, "fashion_mnist_model.h5")
callbacks = [
    # Write training/validation logs under logdir for TensorBoard.
    k.callbacks.TensorBoard(logdir),
    # Save the model only when the monitored metric (val_loss by default) improves.
    k.callbacks.ModelCheckpoint(out_put_model_file, save_best_only=True),
    # Stop after 5 epochs without an improvement of at least 1e-3.
    k.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]
# Train on the standardized arrays (see the complete code below).
history = model.fit(x_train_scaled, y_train, epochs=10,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks=callbacks)
完整代码（由 Jupyter Notebook 导出，`# In[n]:` 注释标记各个单元格）：

#!/usr/bin/env python
# coding: utf-8
# In[2]:
import tensorflow as tf
import tensorflow.keras as k
import numpy as np
import matplotlib.pyplot as plt
import os
# In[3]:
fashion_mnist = k.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# Hold out the FIRST 5,000 training images as the validation set and train on
# the remaining ones. The original assignment was inverted: it trained on only
# 5,000 images and validated on everything else.
# The right-hand side is evaluated before rebinding, so slicing x_train/y_train
# here is safe.
x_valid, x_train = x_train[:5000], x_train[5000:]
y_valid, y_train = y_train[:5000], y_train[5000:]
# In[4]:
from sklearn.preprocessing import StandardScaler
# Standardize pixels. reshape(-1, 1) makes StandardScaler treat every pixel as a
# sample of a single feature, so it learns one global mean/std. Those statistics
# come from the training split only and are reused unchanged for valid/test.
scaler = StandardScaler()
x_train_scaled = (
    scaler.fit_transform(x_train.astype(np.float32).reshape(-1, 1))
    .reshape(-1, 28, 28)
)
# Valid and test are transformed (not re-fitted), in that order.
x_valid_scaled, x_test_scaled = (
    scaler.transform(images.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28)
    for images in (x_valid, x_test)
)
# In[5]:
# Build the model: flatten each 28x28 image, pass it through two ReLU hidden
# layers (300 and 100 units), and finish with a 10-way softmax output.
model = k.Sequential([
    k.layers.Flatten(input_shape=[28, 28]),
    k.layers.Dense(300, activation="relu"),
    k.layers.Dense(100, activation="relu"),
    k.layers.Dense(10, activation="softmax"),
])
model.compile(
    # "sparse" variant: labels are integer class ids rather than one-hot vectors.
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)
# In[7]:
# Directory that receives both the TensorBoard logs and the checkpoint file.
logdir = "./callback"
# exist_ok=True replaces the original exists()/mkdir() pair, whose body had lost
# its indentation, and also avoids the check-then-create race.
os.makedirs(logdir, exist_ok=True)
out_put_model_file = os.path.join(logdir, "fashion_mnist_model.h5")
callbacks = [
    # Write training/validation logs under logdir for TensorBoard.
    k.callbacks.TensorBoard(logdir),
    # Save the model only when the monitored metric (val_loss by default) improves.
    k.callbacks.ModelCheckpoint(out_put_model_file, save_best_only=True),
    # Stop after 5 epochs without an improvement of at least 1e-3.
    k.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]
# Train on the standardized arrays from the previous cell; the original passed
# the raw x_train/x_valid, leaving x_*_scaled unused.
history = model.fit(x_train_scaled, y_train, epochs=10,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks=callbacks)
# In[ ]:
import pandas as pd
def plot_curve(history, ylim=(0, 1)):
    """Plot every metric recorded in a Keras ``History`` object, one line each.

    Args:
        history: the object returned by ``model.fit``. Its ``.history`` dict
            (metric name -> per-epoch values) is turned into a DataFrame and
            plotted with epochs on the x axis.
        ylim: ``(bottom, top)`` limits for the y axis. The default ``(0, 1)``
            matches the original behavior and suits accuracy. Pass ``None`` to
            let matplotlib autoscale, e.g. when the loss exceeds 1.
    """
    pd.DataFrame(history.history).plot(figsize=(8, 5))
    plt.grid(True)
    if ylim is not None:
        plt.gca().set_ylim(*ylim)
    plt.show()


plot_curve(history)
# In[ ]:
来源：https://www.cnblogs.com/superxuezhazha/p/12363521.html
