Custom loss function without using keras backend library

Submitted by 。_饼干妹妹 on 2019-12-02 00:21:50

If I understand the question correctly, you want to compute the loss with arbitrary Python code that runs whenever the model evaluates the loss function.

Here is an example:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K

FACTORS = np.array([[0.5, 2.0, 4.0]])

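def ext_function(inputs):
  """ This can be an arbitrary Python function of the inputs.
  Inside tf.py_function, inputs is an EagerTensor, which can be converted into a NumPy array.
  """
  r = np.dot(inputs, FACTORS.T)
  # Cast explicitly so the result matches Tout=tf.float32 in the loss below
  return r.astype(np.float32)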

class LossFunction(object):
  def __init__(self, model):
    # Use model to obtain the inputs
    self.model = model

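  def __call__(self, y_true, y_pred, sample_weight=None):
    """ Ignore the y_true value passed to fit() and compute it instead using
    ext_function.
    """
    # This relies on self.model.inputs[0] being the tensor fed with each batch,
    # as in graph-mode (TF 1.x) Keras; in TF 2.x training it is a symbolic
    # Keras input rather than the batch, so this may not work there unchanged.
    y_true = tf.py_function(ext_function, [self.model.inputs[0]], Tout=tf.float32)
    v = keras.losses.mean_squared_error(y_true, y_pred)
    return K.mean(v)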

def make_model():
  inp = Input(shape=(3,))
  out = Dense(1, use_bias=False)(inp)
  model = Model(inp, out)
  model.compile('adam', LossFunction(model))
  return model

model = make_model()
model.summary()
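The loss above works because tf.py_function lets graph code call ordinary Python, passing its arguments in as eager tensors. The standalone sketch below (assuming TF 2.x; `targets_from_inputs` is a hypothetical name) shows just that mechanism outside of Keras: the function is traced into a graph by tf.function, yet ext_function still runs as plain NumPy on every call.

```python
import numpy as np
import tensorflow as tf

FACTORS = np.array([[0.5, 2.0, 4.0]])

def ext_function(inputs):
    # Inside tf.py_function, `inputs` is an eager tensor, so .numpy() is available
    return np.dot(inputs.numpy(), FACTORS.T).astype(np.float32)

@tf.function
def targets_from_inputs(x):
    # Graph code (such as a compiled loss) calling arbitrary Python via py_function
    y = tf.py_function(ext_function, [x], Tout=tf.float32)
    y.set_shape([None, 1])  # py_function drops static shape information
    return y

x = tf.constant([[1.0, 1.0, 1.0], [2.0, 0.0, 0.0]])
print(targets_from_inputs(x).numpy())  # [[6.5] [1. ]]
```

Note the set_shape call: the output of py_function has an unknown shape, which some downstream ops may need restored.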

Test:

import numpy as np


N_SAMPLES = 100
X = np.random.rand(N_SAMPLES, 3)
# Placeholder targets: the loss ignores y_true and recomputes it from X
Y_dummy = np.random.rand(N_SAMPLES)

history = model.fit(X, Y_dummy, epochs=1000, verbose=False)
print(history.history['loss'][-1])

And it actually learns something; the Dense layer's kernel should move toward the values in FACTORS:

model.layers[1].get_weights()

Note that it is much simpler to precompute the correct value of Y and pass it to fit() as the target. I don't know the exact constraints of your problem, but if at all possible, pre-generate Y rather than using the example above.
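A minimal sketch of the pre-generated alternative, assuming TF 2.x and that ext_function can be applied to the whole training set up front: compute Y with NumPy once, then train with the built-in "mse" loss. The learning rate and epoch count are illustrative choices, not values from the original answer.

```python
import numpy as np
import tensorflow as tf

FACTORS = np.array([[0.5, 2.0, 4.0]])

def ext_function(inputs):
    # Same target function as above, applied to plain NumPy data before training
    return np.dot(inputs, FACTORS.T)

N_SAMPLES = 100
X = np.random.rand(N_SAMPLES, 3).astype(np.float32)
Y = ext_function(X).astype(np.float32)  # shape (N_SAMPLES, 1)

inp = tf.keras.Input(shape=(3,))
dense = tf.keras.layers.Dense(1, use_bias=False)
model = tf.keras.Model(inp, dense(inp))
# With real targets, a standard loss suffices: no py_function or model.inputs needed
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss="mse")

history = model.fit(X, Y, epochs=100, verbose=0)
print(history.history["loss"][-1])
print(dense.get_weights()[0])  # should approach FACTORS.T
```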

I've used the trick above to create custom metrics weighted by class, i.e., in scenarios where one of the input features is a class label and the desired loss function is a weighted average of the per-class losses.
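As a rough illustration of the per-class weighting, here is a NumPy sketch of the kind of function one might wrap with tf.py_function, in the same way ext_function is wrapped in the loss above. The name `per_class_weighted_loss` and the `class_weights` dict are hypothetical: losses are first averaged within each class, then the class averages are combined using the given weights.

```python
import numpy as np

def per_class_weighted_loss(losses, classes, class_weights):
    """Hypothetical helper: average per-sample losses within each class,
    then take a weighted average of those class means.
    class_weights maps class label -> weight (missing classes default to 1.0)."""
    losses = np.asarray(losses, dtype=np.float64)
    classes = np.asarray(classes)
    total, weight_sum = 0.0, 0.0
    for c in np.unique(classes):
        w = class_weights.get(c.item(), 1.0)
        total += w * losses[classes == c].mean()
        weight_sum += w
    return total / weight_sum

# Class 0 mean = 2.0, class 1 mean = 10.0 -> (1*2 + 3*10) / 4 = 8.0
print(per_class_weighted_loss([1.0, 3.0, 10.0], [0, 0, 1], {0: 1.0, 1: 3.0}))
```

Averaging within each class first keeps a frequent class from dominating the result just because it has more samples in the batch.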
