Calculate/Visualize Tensorflow Keras Dense model layer relative connection weights w.r.t output classes

Submitted by 倾然丶 夕夏残阳落幕 on 2019-12-06 14:54:51

I had come across a similar problem, but I was more concerned with visualizing the model parameters (weights and biases) than the model features, since I also wanted to explore and look inside the black box.

For example, the following is a snippet of a shallow neural network with three hidden Dense layers (each followed by Dropout) and a softmax output layer.

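from time import time

from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.models import Sequential

# X: feature array of shape (n_samples, 13); Y: integer class labels 0-7
model = Sequential()
model.add(Dense(128, input_dim=13, kernel_initializer='uniform', activation='relu'))
model.add(Dropout(0.1))
model.add(Dense(64, kernel_initializer='uniform', activation='relu'))
model.add(Dropout(0.1))
model.add(Dense(64, kernel_initializer='uniform', activation='relu'))
model.add(Dropout(0.1))
model.add(Dense(8, kernel_initializer='uniform', activation='softmax'))

# Compile model (sparse_categorical_crossentropy expects integer labels)
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Using TensorBoard to visualise the model; each run logs to its own timestamped directory.
# write_grads and batch_size were accepted by older Keras versions but are ignored
# or rejected by newer TensorFlow/Keras versions, so they are left out here.
ks = TensorBoard(log_dir="/your_full_path/logs/{}".format(time()), histogram_freq=1, write_graph=True)

# Fit the model
model.fit(X, Y, epochs=64, shuffle=True, batch_size=10, verbose=2, validation_split=0.2, callbacks=[ks])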
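
To see what the histograms will contain, it helps to list the parameters of the model above. Each Dense layer holds a kernel of shape (inputs, units) and a bias of shape (units,), and the Dropout layers add no parameters. The plain-Python sketch below works out those shapes for the 13-128-64-64-8 architecture. The helper name dense_param_shapes is hypothetical and is not a Keras API. model.summary() in Keras should report the same per-layer counts.

```python
# Hypothetical helper: weight (kernel) and bias shapes for a stack of Dense layers.
def dense_param_shapes(input_dim, units_per_layer):
    """Return a list of (kernel_shape, bias_shape, param_count) per Dense layer."""
    shapes = []
    fan_in = input_dim
    for units in units_per_layer:
        kernel_shape = (fan_in, units)  # one weight per input-to-unit connection
        bias_shape = (units,)           # one bias per unit
        count = fan_in * units + units
        shapes.append((kernel_shape, bias_shape, count))
        fan_in = units                  # this layer's output feeds the next layer
    return shapes


# The Dense layers of the model above (Dropout layers have no weights).
layers = dense_param_shapes(13, [128, 64, 64, 8])
for kernel, bias, count in layers:
    print(f"kernel {kernel}, bias {bias}, params {count}")
print("total trainable params:", sum(c for _, _, c in layers))
```

With histogram_freq=1, TensorBoard records one histogram per weight tensor. For this model, that means a kernel histogram and a bias histogram for each of these four Dense layers.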

To be able to visualize the parameters, keep a few important things in mind:

  1. Always provide validation data to model.fit() (here via validation_split). Without it, older Keras versions refuse to compute the histograms.

  2. Make sure histogram_freq is greater than 0. Its default value of 0 disables histogram computation.

  3. Pass the TensorBoard callback to model.fit() inside a list, via the callbacks argument (callbacks=[ks]).
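
Point 1 concerns validation_split=0.2. The Keras documentation states that the validation data is taken from the last samples of X and Y before any shuffling. shuffle=True then only shuffles the remaining training part. Below is a minimal NumPy sketch of that split. The 100 synthetic samples with 13 features and 8 classes are made up for illustration to match the model's input and output sizes.

```python
import numpy as np

# Synthetic stand-ins for X and Y (assumption: 100 samples, 13 features, labels 0-7).
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 13))
Y = rng.integers(0, 8, size=100)

validation_split = 0.2
# The validation slice is the last fraction of the arrays, taken before shuffling.
split_at = int(len(X) * (1.0 - validation_split))
X_train, X_val = X[:split_at], X[split_at:]
Y_train, Y_val = Y[:split_at], Y[split_at:]

print(X_train.shape, X_val.shape)  # (80, 13) (20, 13)
```

Because the split happens before shuffling, data sorted by class would leave only the last classes in the validation set. Shuffling X and Y together once before calling fit() avoids that.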

Once this is done, open a command prompt (cmd) in /your_full_path, the directory that contains logs/, and type the following command:

tensorboard --logdir=logs/
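
The callback writes each training run to its own timestamped subdirectory (logs/<time()>), while the tensorboard command points at the parent logs/ directory. TensorBoard lists every subdirectory that contains event files as a separate run, so successive trainings can be compared side by side. The standard-library sketch below reproduces that layout. A temporary directory stands in for /your_full_path, and make_run_dir is a hypothetical helper, not part of Keras or TensorBoard.

```python
import os
import tempfile
from time import sleep, time


def make_run_dir(base_logdir):
    """Hypothetical helper: one timestamped subdirectory per training run,
    mirroring log_dir="/your_full_path/logs/{}".format(time())."""
    run_dir = os.path.join(base_logdir, "{}".format(time()))
    os.makedirs(run_dir, exist_ok=True)
    return run_dir


# A temporary directory stands in for /your_full_path.
with tempfile.TemporaryDirectory() as full_path:
    logs = os.path.join(full_path, "logs")  # the directory to pass to --logdir
    run_1 = make_run_dir(logs)
    sleep(0.1)                              # ensure the second timestamp differs
    run_2 = make_run_dir(logs)              # a later training run
    runs = sorted(os.listdir(logs))
    same_parent = os.path.dirname(run_1) == os.path.dirname(run_2) == logs

print(len(runs), same_parent)  # 2 True
```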

This prints a local address (by default http://localhost:6006) that you can open in your web browser to access TensorBoard. The histograms, distributions, and loss and accuracy curves are all available as plots and can be selected from the menu bar at the top.
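
Each histogram panel summarizes the distribution of all values in one weight tensor at a given epoch. The NumPy sketch below shows what such a panel bins for a freshly initialized first-layer kernel. It assumes Keras's 'uniform' initializer, which is RandomUniform and by default samples from [-0.05, 0.05], so the 13 x 128 kernel starts inside that range. The choice of 10 bins is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(42)
# Stand-in for the first Dense layer's kernel right after 'uniform' initialization
# (assumption: values drawn from U(-0.05, 0.05), shape (13, 128)).
kernel = rng.uniform(-0.05, 0.05, size=(13, 128))

# A histogram panel bins all values of the tensor, regardless of their position.
counts, edges = np.histogram(kernel.ravel(), bins=10, range=(-0.05, 0.05))
for lo, hi, c in zip(edges[:-1], edges[1:], counts):
    print(f"{lo:+.3f} to {hi:+.3f}: {c}")
print("values binned:", counts.sum())  # 13 * 128 = 1664
```

After training, applying the same binning to model.layers[0].get_weights()[0] would show how far the kernel values have spread or shifted from this initial range. TensorBoard stacks that change epoch by epoch.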

I hope this answer gives a hint about the procedure for visualizing model parameters. I struggled a bit myself, since the points above weren't documented together in one place.

Do let me know if it helped.

Below is the Keras documentation link for reference:

https://keras.io/callbacks/#tensorboard
