Can Numba be used to compile Python code that interfaces with TensorFlow? That is, computations outside of the TensorFlow universe would run with Numba for speed. I have not fo
I know this does not directly answer your question, but it might be a good alternative. Numba uses just-in-time (JIT) compilation, and TensorFlow has its own JIT compiler, XLA, which is separate from the Numba ecosystem. You can follow the instructions in the official TensorFlow documentation here on how to use JIT compilation in TensorFlow.
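As a rough sketch of that XLA route, the dice coefficient from the example below can be written with TensorFlow ops and compiled by passing jit_compile=True to tf.function. This assumes TensorFlow 2.5 or later, where that argument exists; dice_coeff_xla is a hypothetical name. Because it stays inside TensorFlow, gradients flow through it, unlike a function wrapped in tf.numpy_function.

```python
import tensorflow as tf

@tf.function(jit_compile=True)  # ask TensorFlow to compile this function with XLA
def dice_coeff_xla(y_true, y_pred, smooth=1.0):
    # flatten both tensors so the sums run over every element
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (
        tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

y = tf.constant([[1.0, 0.0], [0.0, 1.0]])
print(float(dice_coeff_xla(y, y)))  # perfect overlap gives 1.0
```

A loss such as 1 - dice_coeff_xla(y_true, y_pred) can then be used directly in model.compile without leaving the TensorFlow graph.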
You can use tf.numpy_function (or tf.compat.v1.py_func, the TensorFlow 1.x tf.py_func) to wrap a Python function and use it as a TensorFlow op. Here is an example I used:
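import numpy as np
import tensorflow as tf
from numba import jit

@jit(nopython=True)
def dice_coeff_nb(y_true, y_pred):
    "Calculates dice coefficient"
    smooth = np.float32(1)
    # flatten both arrays to 1-D
    y_true_f = y_true.ravel()
    y_pred_f = y_pred.ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (np.sum(y_true_f) +
                                            np.sum(y_pred_f) + smooth)
    return score

@jit(nopython=True)
def dice_loss_nb(y_true, y_pred):
    "Calculates dice loss"
    loss = 1 - dice_coeff_nb(y_true, y_pred)
    return loss

def bce_dice_loss_nb(y_true, y_pred):
    "Adds dice_loss to categorical_crossentropy"
    # the Numba function returns a Python float, declared here as float64
    dice = tf.numpy_function(dice_loss_nb, [y_true, y_pred], tf.float64)
    # give it a static scalar shape and cast it to y_pred's dtype (usually
    # float32) so it can be added to the cross-entropy without a dtype error
    dice = tf.cast(tf.reshape(dice, []), y_pred.dtype)
    loss = dice + tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    return loss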
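I then used this loss function to train a tf.keras model. Note that TensorFlow cannot take gradients through tf.numpy_function, so only the cross-entropy term drives the weight updates; the dice term only shifts the reported loss value: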
...
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss=bce_dice_loss_nb)