Can Numba be used with TensorFlow?

梦如初夏 2020-12-31 10:24

Can Numba be used to compile Python code that interfaces with TensorFlow? That is, computations outside of the TensorFlow universe would run with Numba for speed. I have not fo

2 Answers
  • 2020-12-31 10:48

    I know that this does not directly answer your question, but it might be a good alternative. Numba uses just-in-time (JIT) compilation, and TensorFlow has its own JIT support. So you can follow the instructions in the official TensorFlow documentation on how to use JIT in TensorFlow (outside the Numba ecosystem).
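
    As a rough illustration of the TensorFlow-side JIT mentioned above, here is a minimal sketch. It assumes a recent TensorFlow 2.x release in which tf.function accepts the jit_compile argument (XLA compilation). The function name dice_coeff_xla and the sample tensors are made up for this example and do not come from the TensorFlow documentation.

```python
import tensorflow as tf


# jit_compile=True asks TensorFlow to compile the traced function with XLA.
# Assumption: the installed TensorFlow build supports XLA on this device.
@tf.function(jit_compile=True)
def dice_coeff_xla(y_true, y_pred, smooth=1.0):
    # Flatten both tensors, then compute the smoothed Dice coefficient
    # using TensorFlow ops only (no NumPy, no Numba).
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (
        tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth
    )


# Hypothetical sample data: 3 positives in y_true, 2 in y_pred, 2 overlapping.
y_true = tf.constant([[1.0, 0.0], [1.0, 1.0]])
y_pred = tf.constant([[1.0, 0.0], [0.0, 1.0]])
score = dice_coeff_xla(y_true, y_pred)
```

    Because everything stays inside TensorFlow ops, a function written this way remains differentiable, which matters if it is used as part of a loss.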

  • 2020-12-31 10:59

    You can use tf.numpy_function (or tf.py_func in TensorFlow 1.x) to wrap a Python function and use it as a TensorFlow op. Here is an example that I used:

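    import numpy as np
    import tensorflow as tf
    from numba import jit

    @jit
    def dice_coeff_nb(y_true, y_pred):
        "Calculates dice coefficient"
        smooth = np.float32(1)  # smoothing term avoids division by zero
        # np.ravel flattens like np.reshape(x, [-1]) and is supported by Numba
        y_true_f = np.ravel(y_true)
        y_pred_f = np.ravel(y_pred)
        intersection = np.sum(y_true_f * y_pred_f)
        score = (2. * intersection + smooth) / (np.sum(y_true_f) +
                                                np.sum(y_pred_f) + smooth)
        return score

    @jit
    def dice_loss_nb(y_true, y_pred):
        "Calculates dice loss"
        loss = 1 - dice_coeff_nb(y_true, y_pred)
        return loss

    def bce_dice_loss_nb(y_true, y_pred):
        "Adds dice_loss to categorical_crossentropy"
        # tf.numpy_function runs the Numba-compiled function as a TensorFlow op.
        # Per the TensorFlow docs, gradients cannot be taken through it, so only
        # the cross-entropy term contributes gradients during training.
        dice = tf.numpy_function(dice_loss_nb, [y_true, y_pred], tf.float64)
        bce = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
        # Cast so both terms share a dtype before they are added.
        loss = tf.cast(dice, bce.dtype) + bce
        return loss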
    

    Then I used this loss function in training a tf.keras model:

    ...
    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='adam', loss=bce_dice_loss_nb)
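
    Before wiring the loss into a model, the Numba part can be sanity-checked on plain NumPy arrays, since tf.numpy_function simply hands NumPy arrays to the wrapped function. The following is a minimal sketch of such a check. The sample arrays are hypothetical, and the no-op fallback decorator is only there so the sketch still runs if Numba is not installed (an assumption made for this illustration, not part of the setup described above).

```python
import numpy as np

try:
    from numba import jit
except ImportError:
    # Fallback for illustration only: a no-op decorator with the same call shape.
    def jit(*args, **kwargs):
        def wrap(func):
            return func
        return wrap


@jit(nopython=True)
def dice_coeff_nb(y_true, y_pred):
    # Smoothed Dice coefficient on flattened arrays.
    smooth = np.float32(1)
    y_true_f = np.ravel(y_true)
    y_pred_f = np.ravel(y_pred)
    intersection = np.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (
        np.sum(y_true_f) + np.sum(y_pred_f) + smooth)


@jit(nopython=True)
def dice_loss_nb(y_true, y_pred):
    # Dice loss is one minus the Dice coefficient.
    return 1 - dice_coeff_nb(y_true, y_pred)


# Hypothetical sample data: 3 positives in y_true, 2 in y_pred, 2 overlapping,
# so the coefficient is (2*2 + 1) / (3 + 2 + 1) = 5/6.
y_true = np.array([[1, 0], [1, 1]], dtype=np.float32)
y_pred = np.array([[1, 0], [0, 1]], dtype=np.float32)
coeff = float(dice_coeff_nb(y_true, y_pred))
loss = float(dice_loss_nb(y_true, y_pred))
```

    If the values match a hand calculation like the one in the comment, the same functions can be passed to tf.numpy_function as shown in the loss above.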
    