Question
import tensorflow as tf
import numpy as np
from sklearn.model_selection import train_test_split
np.random.seed(4213)
data = np.random.randint(low=1,high=29, size=(500, 160, 160, 10))
labels = np.random.randint(low=0,high=5, size=(500, 160, 160))
nclass = len(np.unique(labels))
print(nclass)
samples, width, height, nbands = data.shape
X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.25, random_state=421)
print(X_train.shape)
print(y_train.shape)
arch = tf.keras.applications.VGG16(input_shape=[width, height, nbands],
                                   include_top=False,
                                   weights=None)
model = tf.keras.Sequential()
model.add(arch)
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(nclass))
model.compile(optimizer=tf.keras.optimizers.Adam(0.0001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
model.fit(X_train,
          y_train,
          epochs=3,
          batch_size=32,
          verbose=2)
res = model.predict(X_test)
print(res.shape)
When I run the above code for semantic segmentation, I get the following exception:

Exception has occurred: InvalidArgumentError
Incompatible shapes: [32,160,160] vs. [32]
  [[node Equal (defined at c...:38) ]] [Op:__inference_train_function_1815]
tensorflow.python.framework.errors_impl.InvalidArgumentError
For reference, here is TensorFlow's segmentation tutorial: https://www.tensorflow.org/tutorials/images/segmentation
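The shapes in the error message can be reproduced without TensorFlow. The failing op is Equal, which is consistent with the accuracy metric comparing the predicted class ids (one per image, shape [32]) against the labels (one per pixel, shape [32,160,160]). The NumPy sketch below uses a hypothetical helper, sparse_accuracy, as a simplified stand-in for what a sparse categorical accuracy does (argmax over the class axis, then an element-wise comparison with the integer labels); it is not TensorFlow's implementation. It also shows the two pairings whose shapes do agree.

```python
import numpy as np

def sparse_accuracy(y_true, logits):
    """Hypothetical, simplified sparse categorical accuracy (not TensorFlow's code).

    Takes the argmax over the last (class) axis of `logits` and compares it
    element-wise with the integer class ids in `y_true`.
    """
    y_pred = np.argmax(logits, axis=-1)  # drop the class axis
    if y_pred.shape != y_true.shape:
        raise ValueError(
            f"Incompatible shapes: {list(y_true.shape)} vs. {list(y_pred.shape)}")
    return float(np.mean(y_pred == y_true))

BATCH, HEIGHT, WIDTH, N_CLASSES = 32, 160, 160, 5
rng = np.random.default_rng(0)

# The question's setup: one score vector per image, but one label per pixel.
per_image_logits = rng.normal(size=(BATCH, N_CLASSES))
per_pixel_labels = rng.integers(0, N_CLASSES, size=(BATCH, HEIGHT, WIDTH))
try:
    sparse_accuracy(per_pixel_labels, per_image_logits)
except ValueError as err:
    print(err)  # Incompatible shapes: [32, 160, 160] vs. [32]

# Consistent pairing 1 (classification): one label per image.
per_image_labels = rng.integers(0, N_CLASSES, size=(BATCH,))
print(sparse_accuracy(per_image_labels, per_image_logits))

# Consistent pairing 2 (segmentation): one score vector per pixel.
per_pixel_logits = rng.normal(size=(BATCH, HEIGHT, WIDTH, N_CLASSES))
print(sparse_accuracy(per_pixel_labels, per_pixel_logits))
```

The first consistent pairing is what the answer below does (one label per image). The second is the per-pixel layout used in semantic segmentation, where the network has to output a (N, HEIGHT, WIDTH, N_CLASSES) tensor instead of the flattened (N, N_CLASSES) tensor that Flatten() followed by Dense(nclass) produces.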
Answer 1:
Your issue comes from the definition of labels. The model as written ends in Flatten() followed by Dense(nclass), so it predicts one class per image, and you should therefore assign a single label to each image: if you have 500 images, you should have only 500 labels. (To avoid this kind of mistake, it is always a good idea to use Python constants for N_IMAGES, WIDTH, HEIGHT, N_CHANNELS and N_CLASSES.) Try changing labels as follows:
import tensorflow as tf
import numpy as np
from sklearn.model_selection import train_test_split
np.random.seed(4213)
N_IMAGES, WIDTH, HEIGHT, N_CHANNELS = (500, 160, 160, 10)
N_CLASSES = 5
data = np.random.randint(low=1,high=29, size=(N_IMAGES, WIDTH, HEIGHT, N_CHANNELS))
labels = np.random.randint(low=0,high=N_CLASSES, size=(N_IMAGES))
#...
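Independently of the shape problem, the model (presumably unchanged in the #... part) ends in Dense(nclass) with no activation, so it outputs raw scores (logits), whereas tf.keras.losses.SparseCategoricalCrossentropy() assumes its input is a probability distribution unless it is constructed with from_logits=True. Passing loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), or adding activation='softmax' to the last Dense layer, makes the two agree. As an illustration only, the NumPy sketch below shows what a cross-entropy computed "from logits" means; log_softmax and sparse_ce_from_logits are hypothetical helpers, not Keras functions.

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the last (class) axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

def sparse_ce_from_logits(labels, logits):
    """Hypothetical helper: mean sparse categorical cross-entropy on raw scores."""
    log_probs = log_softmax(logits)
    # Pick the log-probability of the true class for every sample (or pixel).
    picked = np.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
    return float(-np.mean(picked))

N_IMAGES, N_CLASSES = 4, 5
rng = np.random.default_rng(4213)
logits = rng.normal(size=(N_IMAGES, N_CLASSES))        # like Dense(N_CLASSES) output
labels = rng.integers(0, N_CLASSES, size=(N_IMAGES,))  # one label per image
print(sparse_ce_from_logits(labels, logits))
```

Because the helper indexes along the last axis, the same function also accepts per-pixel labels of shape (N, H, W) together with logits of shape (N, H, W, N_CLASSES).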
Source: https://stackoverflow.com/questions/63129349/how-to-solve-tensorflow-python-framework-errors-impl-invalidargumenterror