How does one pickle arbitrary pytorch models that use lambda functions?

邮差的信 submitted on 2020-05-17 06:17:47

Question


I currently have a neural network module:

import torch.nn as nn

class NN(nn.Module):
    def __init__(self, args, lambda_f, nn1, loss, opt):
        super().__init__()
        self.args = args
        self.lambda_f = lambda_f
        self.nn1 = nn1
        self.loss = loss
        self.opt = opt
        # more nn.Parameter definitions etc...

    def forward(self, x):
        # some code using the fields; this line is an assumed placeholder, the actual computation is elided
        out = self.lambda_f(self.nn1(x))
        return out

I am trying to checkpoint it, but because PyTorch saves using state_dicts, I can't save the lambda functions I was actually using if I checkpoint with torch.save etc. I literally want to save everything without issue and reload it later to train on GPUs. I am currently using this:

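from pathlib import Path
import dill as pickle

def save_ckpt(path_to_ckpt, crazy_mdl):
    path_to_ckpt = Path(path_to_ckpt)
    ## Make dir. Throw no exceptions if it already exists
    path_to_ckpt.mkdir(parents=True, exist_ok=True)
    ckpt_path_plus_path = path_to_ckpt / 'db'

    ## Pickle args (the original relied on a global `db` dict and a global `crazy_mdl`)
    db = {'crazy_mdl': crazy_mdl}
    # 'wb' overwrites the file; appending with 'ab' would leave older pickles first in the file
    with open(ckpt_path_plus_path, 'wb') as db_file:
        pickle.dump(db, db_file)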

Currently it throws no errors when I checkpoint it, and the checkpoint is saved.
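For comparison: torch.save uses the standard-library pickle by default, which stores a function by reference (its module and qualified name), so an object holding a lambda cannot be serialized while one holding a named module-level function can; dill gets around this by serializing the function itself. A stdlib-only sketch; `Holder` and `can_pickle` are hypothetical names, with `Holder` standing in for the model object:

```python
import pickle


def double(x):
    # Named, module-level function: pickle stores it by reference (module + qualified name)
    return 2 * x


class Holder:
    """Hypothetical stand-in for a model object that keeps a callable as an attribute."""

    def __init__(self, f):
        self.f = f


def can_pickle(obj):
    """Return True if the standard-library pickle can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


print(can_pickle(Holder(double)))           # True
print(can_pickle(Holder(lambda x: 2 * x)))  # False
```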

I am worried that when I train it there might be a subtle bug even if no exceptions/errors are raised, or that something unexpected might happen (e.g. odd behavior when saving to disk on the cluster, who knows).
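One silent failure of exactly this kind is opening the checkpoint file in append mode ('ab'): it never raises, but each save adds another pickle after the old ones, and a single load returns only the first (oldest) object. A stdlib-only sketch with a hypothetical temporary path; dill's load reads one object per call in the same way, since it follows the pickle protocol:

```python
import os
import pickle
import tempfile

# Hypothetical checkpoint location in a fresh temporary directory
path = os.path.join(tempfile.mkdtemp(), "db")

# Save two "checkpoints" to the same file in append mode
for epoch in (1, 2):
    with open(path, "ab") as f:
        pickle.dump({"epoch": epoch}, f)

# A single pickle.load returns only the FIRST object in the file, i.e. the stale checkpoint
with open(path, "rb") as f:
    restored = pickle.load(f)

print(restored)  # {'epoch': 1}
```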

Is this safe to do with PyTorch classes/nn models, especially if we want to resume training on GPUs?

Cross-posted:

  • How does one pickle arbitrary pytorch models that use lambda functions?
  • https://discuss.pytorch.org/t/how-does-one-pickle-arbitrary-pytorch-models-that-use-lambda-functions/79026
  • https://www.reddit.com/r/pytorch/comments/gagpjg/how_does_one_pickle_arbitrary_pytorch_models_that/?
  • https://www.quora.com/unanswered/How-does-one-pickle-arbitrary-PyTorch-models-that-use-lambda-functions

Answer 1:


I'm the dill author. I use dill (and klepto) to save classes that contain trained ANNs inside of lambda functions. I tend to use combinations of mystic and sklearn, so I can't speak directly to pytorch, but I'd assume it works the same way. The place where you have to be careful is when a lambda contains a pointer to an object external to the lambda, for example y = 4; f = lambda x: x+y. This might seem obvious, but dill will pickle the lambda and, depending on the rest of the code and the serialization variant, may not serialize the value of y. I've seen many cases where people serialize a trained estimator inside some function (or lambda, or class), and then the results aren't "correct" when they restore the function from serialization.

The overarching cause is that the function wasn't encapsulated, i.e. not all objects it needs to yield the correct results are stored in the pickle. Even in that case you can get the "correct" results back, but you'd need to recreate the same environment you had when you pickled the estimator (i.e. all the same values it depends on in the surrounding namespace). The takeaway: try to make sure that all variables used in the function are defined within the function. Here's a portion of a class I've recently started to use myself (it should be in the next release of mystic):

import numpy as np  # module level: the lambda below looks up np in the module's globals when called

class Estimator(object):
    "a container for a trained estimator and transform (not a pipeline)"
    def __init__(self, estimator, transform):
        """a container for a trained estimator and transform

    Input:
        estimator: a fitted sklearn estimator
        transform: a fitted sklearn transform
        """
        self.estimator = estimator
        self.transform = transform
        self.function = lambda *x: float(self.estimator.predict(self.transform.transform(np.array(x).reshape(1,-1))).reshape(-1))
    def __call__(self, *x):
        "f(*x) for x of xtest and predict on fitted estimator(transform(xtest))"
        # an `import numpy as np` here would be local to __call__ and invisible to the lambda
        return self.function(*x)

Note that when the function is called, everything it uses (including np) is defined in the surrounding namespace. As long as PyTorch estimators serialize as expected (without external references), you should be fine if you follow the guidelines above.
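Both cautions above (a lambda that reads y from outside itself, and np having to be resolvable when the function is called) come down to how Python resolves a lambda's free names: anything that is not a parameter or captured in a closure is looked up in the defining module's globals at call time, not in the scope of whichever method calls it. A plain-Python sketch without dill; `make_adder` and `Demo` are hypothetical, math stands in for numpy, and reassigning y only simulates a different environment after restoring:

```python
# Case 1: a global name is looked up when the lambda is *called*, not when it is defined
y = 4
f_global = lambda x: x + y


def make_adder(y):
    # Case 2: y is a parameter here, so its value is captured in the closure
    return lambda x: x + y


f_closed = make_adder(4)
print(f_closed.__closure__[0].cell_contents)  # 4

y = 100  # stands in for "the surrounding namespace differs after restoring"
print(f_global(1))  # 101
print(f_closed(1))  # 5


# Case 3: an import inside one method is local to that method only
class Demo:
    """Hypothetical mini version of the Estimator pattern; math stands in for numpy."""

    def __init__(self):
        # math is a free name here, resolved in the module's globals at call time
        self.function = lambda x: math.sqrt(x)

    def __call__(self, x):
        import math  # binds math in __call__'s local scope only; the lambda cannot see it
        return self.function(x)


try:
    Demo()(4.0)
    local_import_worked = True
except NameError:
    local_import_worked = False
print(local_import_worked)  # False

import math  # a module-level import is visible to the lambda
print(Demo()(4.0))  # 2.0
```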




Answer 2:


Yes, I think it is safe to use dill to pickle lambda functions etc. I have been using torch.save with dill to save state dicts and have had no problems resuming training on GPU as well as CPU, unless the model class was changed. Even if the model class was changed (adding/deleting some parameters), I could load the state dict, modify it, and load it into the model.
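Modifying a state dict after the class changed can be as simple as keeping the entries whose keys still exist. A torch-free sketch; `adapt_state_dict` is a hypothetical helper, and the plain lists stand in for tensors:

```python
def adapt_state_dict(saved, current):
    """Hypothetical helper: merge a saved state dict into the current model's state dict.

    Keys present in both are taken from the checkpoint; keys only in `current`
    (newly added parameters) keep their fresh values; keys only in `saved`
    (deleted parameters) are dropped.
    """
    merged = dict(current)
    for key, value in saved.items():
        if key in current:
            merged[key] = value
    missing = sorted(set(current) - set(saved))
    unexpected = sorted(set(saved) - set(current))
    return merged, missing, unexpected


# Lists stand in for tensors; the key names are illustrative only
saved = {"nn1.weight": [0.5, 0.5], "nn1.bias": [0.1], "old_head.weight": [1.0]}
current = {"nn1.weight": [0.0, 0.0], "nn1.bias": [0.0], "new_head.weight": [0.0]}

merged, missing, unexpected = adapt_state_dict(saved, current)
print(merged)      # {'nn1.weight': [0.5, 0.5], 'nn1.bias': [0.1], 'new_head.weight': [0.0]}
print(missing)     # ['new_head.weight']
print(unexpected)  # ['old_head.weight']
```

In PyTorch itself, `model.load_state_dict(state_dict, strict=False)` tolerates added or removed keys and returns the `missing_keys` and `unexpected_keys` it found.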

Also, people usually don't save the model objects but only the state dicts, i.e. parameter values, to resume training, along with the hyperparameters/model arguments needed to recreate the same model object later.
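This pattern also sidesteps the lambda problem if lambdas are referred to by name in the saved arguments rather than stored directly. A torch-free sketch; `TinyModel` and the `ACTIVATIONS` registry are hypothetical, with `state_dict`/`load_state_dict` named after the nn.Module methods they imitate:

```python
import pickle

# Hypothetical registry: the lambdas live in code, and only their *names* are saved
ACTIVATIONS = {
    "relu": lambda v: max(v, 0.0),
    "identity": lambda v: v,
}


class TinyModel:
    """Hypothetical stand-in for an nn.Module: constructor args + a dict of parameters."""

    def __init__(self, activation, scale=1.0):
        self.activation = activation  # a key into ACTIVATIONS, not the lambda itself
        self.f = ACTIVATIONS[activation]
        self.params = {"weight": scale}

    def state_dict(self):
        return dict(self.params)

    def load_state_dict(self, state):
        self.params.update(state)

    def forward(self, x):
        return self.f(self.params["weight"] * x)


model = TinyModel(activation="relu")
model.params["weight"] = 3.0  # pretend training changed the parameter

# Checkpoint = constructor arguments + parameter values; nothing in it is a lambda
ckpt = {"args": {"activation": "relu"}, "state": model.state_dict()}
blob = pickle.dumps(ckpt)

# Resume: rebuild the object from its arguments, then load the parameters
restored_ckpt = pickle.loads(blob)
restored = TinyModel(**restored_ckpt["args"])
restored.load_state_dict(restored_ckpt["state"])
print(restored.forward(2.0))  # 6.0
```

With PyTorch the same shape would be `torch.save({"args": ..., "state": model.state_dict()}, path)` and, after rebuilding the model from `args`, `model.load_state_dict(torch.load(path, map_location=...)["state"])`; `map_location` controls which device the saved tensors are loaded onto.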

Saving the model object can sometimes be problematic, as changes to the model class (code) can make the saved object useless. If you don't plan on changing your model class/code at all, so the model object won't change, then saving objects may work well, but in general pickling the module object is not recommended.



Source: https://stackoverflow.com/questions/61510810/how-does-one-pickle-arbitrary-pytorch-models-that-use-lambda-functions
