How to avoid OOM errors in repeated training and prediction in TensorFlow?

Submitted by 梦想与她 on 2021-02-05 09:39:07

Question


I have some TensorFlow code that takes a base model, fine-tunes (trains) it on some data, and then calls predict() on some other data. All of this is encapsulated in a main() function of a module and works fine.

When I run this code in a loop over different base models, however, I get an out-of-memory (OOM) error after, e.g., 7 base models. Is this expected? I would expect Python to clean up after each main() call. Does TensorFlow not do that? How can I force it to?

Edit: here's a minimal working example (MWE). It doesn't reproduce the OOM crash itself, but it shows steadily increasing memory consumption:

import gc
import os

import numpy as np
import psutil
import tensorflow as tf

tf.get_logger().setLevel("ERROR")  # Suppress "tf.function retracing" warnings
process = psutil.Process(os.getpid())
for i in range(100):
    (model := tf.keras.applications.mobilenet.MobileNet()).compile(loss="mse")
    history = model.fit(
        x=(x := tf.zeros((1, *model.input.shape[1:]))),
        y=(y := tf.zeros((1, *model.output.shape[1:]))),
        verbose=0,
    )
    prediction = model.predict(x)
    _ = gc.collect()
    # tf.keras.backend.clear_session()
    print(f"rss {i}: {process.memory_info().rss >> 20} MB")

On my computer (CPU), it prints

rss 0: 374 MB
rss 1: 438 MB
rss 2: 478 MB
rss 3: 517 MB
rss 4: 554 MB
rss 5: 588 MB
rss 6: 634 MB
rss 7: 669 MB
rss 8: 686 MB
rss 9: 726 MB
...
rss 30: 1386 MB
rss 31: 1413 MB
rss 32: 1445 MB
rss 33: 1476 MB
rss 34: 1506 MB
rss 35: 1536 MB
rss 36: 1568 MB
rss 37: 1597 MB
rss 38: 1630 MB
rss 39: 1662 MB
...

With tf.keras.backend.clear_session() uncommented, it's much better, but memory still creeps up slowly:

rss 0: 374 MB
rss 1: 420 MB
rss 2: 418 MB
rss 3: 450 MB
rss 4: 447 MB
rss 5: 469 MB
rss 6: 469 MB
rss 7: 475 MB
rss 8: 487 MB
rss 9: 494 MB
...
rss 40: 519 MB
rss 41: 516 MB
rss 42: 517 MB
rss 43: 520 MB
rss 44: 519 MB
rss 45: 519 MB
rss 46: 521 MB
rss 47: 517 MB
rss 48: 521 MB
rss 49: 521 MB
...
rss 90: 531 MB
rss 91: 531 MB
rss 92: 531 MB
rss 93: 531 MB
rss 94: 532 MB
rss 95: 532 MB
rss 96: 533 MB
rss 97: 534 MB
rss 98: 533 MB
rss 99: 533 MB
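To put rough numbers on "better, but not perfect," the sketch below computes the average RSS growth per iteration from a few of the readings printed above. The values are copied from the two logs; which iterations to compare is an arbitrary choice.

```python
# RSS readings (MB) copied from the two printed runs; keys are iteration numbers.
without_clear_session = {0: 374, 9: 726, 39: 1662}
with_clear_session = {0: 374, 9: 494, 40: 519, 99: 533}


def avg_growth_mb(readings, start, end):
    """Average RSS growth in MB per iteration between two logged iterations."""
    return (readings[end] - readings[start]) / (end - start)


print(f"no clear_session,   iters 0-39:  {avg_growth_mb(without_clear_session, 0, 39):.1f} MB/iter")  # 33.0
print(f"with clear_session, iters 0-99:  {avg_growth_mb(with_clear_session, 0, 99):.1f} MB/iter")     # 1.6
print(f"with clear_session, iters 40-99: {avg_growth_mb(with_clear_session, 40, 99):.2f} MB/iter")    # 0.24
```

Without clear_session() the process grows by about 33 MB per model. With it, most of the remaining growth happens in the first ten or so iterations, after which it drops to a fraction of a megabyte per iteration: small, but not zero.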

Switching the order of gc.collect() and tf.keras.backend.clear_session() did not help, either.
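Because neither gc.collect() nor clear_session() releases everything in-process, a workaround commonly used for this kind of loop (not something the question itself tried) is to run each base model's fine-tune/predict cycle in its own OS process. The operating system then reclaims all of that process's memory when it exits. A minimal standard-library sketch, assuming the existing main() can be invoked as a standalone script for a single base model:

```python
import subprocess
import sys


def run_isolated(args):
    """Run `python <args...>` in a separate OS process and return its stdout.

    Whatever the child allocates (TensorFlow graphs, Keras models, caches)
    is released by the OS when the child exits, so it cannot accumulate
    across iterations of the parent loop.
    """
    completed = subprocess.run(
        [sys.executable, *args],  # same interpreter/virtualenv as the parent
        check=True,               # raise CalledProcessError if the child fails or is killed
        capture_output=True,      # collect the child's stdout/stderr
        text=True,                # decode output as str instead of bytes
    )
    return completed.stdout
```

From the parent loop this would be called as, e.g., `for name in base_model_names: run_isolated(["finetune_one.py", name])`, where `finetune_one.py` and `base_model_names` are hypothetical: a script that calls the existing main() for the base model named on its command line and writes its predictions to disk. The parent never imports TensorFlow, so its own memory stays flat. The trade-off is that every child pays the cost of importing TensorFlow and loading its model from scratch.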

Source: https://stackoverflow.com/questions/63411142/how-to-avoid-oom-errors-in-repeated-training-and-prediction-in-tensorflow
