model.summary() can't print output shapes when using a subclassed model

情话喂你 2021-01-01 20:48

These are two ways of creating a Keras model, but the Output Shape columns in their summary() results differ. Obviously, the former prints …
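For reference, here is a minimal sketch of the two approaches being compared. It assumes TensorFlow 2.x with tf.keras; the layers and the names `functional_model`, `SubclassModel` and `subclass_model` are illustrative, not taken from the original question.

```python
import tensorflow as tf
from tensorflow import keras

# Method 1 (functional API): the layer graph is built from a symbolic Input,
# so every layer's output shape is known when summary() runs.
inputs = keras.Input(shape=(24, 24, 3))
x = keras.layers.Flatten()(inputs)
outputs = keras.layers.Dense(10)(x)
functional_model = keras.Model(inputs=inputs, outputs=outputs)


# Method 2 (subclassing): the forward pass is ordinary Python code in call(),
# so Keras does not have a static layer graph to read shapes from.
class SubclassModel(keras.Model):  # illustrative example model
    def __init__(self):
        super().__init__()
        self.flatten = keras.layers.Flatten()
        self.dense = keras.layers.Dense(10)

    def call(self, x):
        return self.dense(self.flatten(x))


subclass_model = SubclassModel()
# Calling the model once on a dummy batch creates its weights (builds it).
subclass_model(tf.zeros((1, 24, 24, 3)))

functional_model.summary()
subclass_model.summary()
```

In tf.keras 2.x, summary() on a subclassed model raises an error until the model is built, and even after building it typically cannot list a concrete per-layer output shape, because the shapes are only produced while call() executes.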

4 answers
  •  Happy的楠姐
    2021-01-01 21:34

    The way I solved the problem is very similar to what Elazar mentioned: override the summary() method in the subclass. Then you can call summary() directly on the subclassed model:

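    from tensorflow.keras import Input, Model
    from tensorflow.keras.layers import Dense, Flatten
    
    class subclass(Model):
        def __init__(self):
            super().__init__()
            # Example layers only; replace them with your own
            self.flatten = Flatten()
            self.dense = Dense(10)
    
        def call(self, x):
            return self.dense(self.flatten(x))
    
        def summary(self, *args, **kwargs):
            # Trace call() on a symbolic Input to get a functional Model,
            # whose summary() can report each layer's output shape
            x = Input(shape=(24, 24, 3))
            model = Model(inputs=[x], outputs=self.call(x))
            return model.summary(*args, **kwargs)
    
    if __name__ == '__main__':
        sub = subclass()
        sub.summary()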
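A variation on the same idea, sketched under the assumption of TensorFlow 2.x with tf.keras: instead of overriding summary() with a hardcoded input shape, a small standalone helper builds the functional wrapper for any subclassed model. The names `functional_view` and `SubclassModel` are hypothetical, chosen for illustration.

```python
import tensorflow as tf
from tensorflow import keras


class SubclassModel(keras.Model):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.conv = keras.layers.Conv2D(8, 3, activation="relu")
        self.flatten = keras.layers.Flatten()
        self.dense = keras.layers.Dense(10)

    def call(self, x):
        return self.dense(self.flatten(self.conv(x)))


def functional_view(model, input_shape):
    """Hypothetical helper: wrap model.call() in a functional Model.

    input_shape excludes the batch dimension, e.g. (24, 24, 3).
    The wrapper reuses the same layer objects, so weights are shared.
    """
    inputs = keras.Input(shape=input_shape)
    return keras.Model(inputs=inputs, outputs=model.call(inputs))


model = SubclassModel()
view = functional_view(model, (24, 24, 3))
view.summary()  # lists per-layer output shapes, e.g. (None, 22, 22, 8) for conv

# The original subclassed model still works as usual on real tensors.
preds = model(tf.zeros((1, 24, 24, 3)))
```

Unlike overriding summary(), this leaves the inherited Model.summary() untouched and lets the caller pass the input shape instead of hardcoding it.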
    
