之前做项目时,我好奇 `model.summary()` 这个函数是怎么实现的,于是把它的源码看了一遍,然后对代码进行了魔改:删掉用不到的部分,重新封装成一个类,并加上了把结果输出到 txt 文件的功能。
'''
class print_summary_magic_modification:
    """Print a layer-by-layer summary table of a model to stdout or a text file.

    A trimmed-down re-implementation of ``model.summary()``: one row per entry
    in ``model.layers`` showing the layer name and class, its output shape and
    its parameter count, framed by ``_``/``=`` rules.
    """

    # Total width of the table and of the separator rules.
    line_length = 65
    # Right edge (exclusive column) of each of the three columns. The last one
    # equals ``line_length`` so rows are exactly as wide as the rules instead
    # of being padded with trailing spaces.
    positions = (29, 55, 65)
    # Header names for the different log elements.
    to_display = ('Layer (type)', 'Output Shape', 'Param #')

    def __init__(self, model, file_path):
        """Store the model to summarize and the text file to append to.

        Args:
            model: object exposing a ``layers`` sequence; each layer must have
                ``name`` and ``count_params()``, and may have ``output_shape``.
            file_path: path of the text file written by ``print_summary2txt``.
        """
        self.model = model
        self.file_path = file_path

    @staticmethod
    def params_nums(weights):
        """Return the total number of scalars in ``weights`` as an ``int``.

        ``set`` makes each distinct weight object count once, even if it
        appears several times in ``weights``. Not used by the summary methods.
        """
        # NOTE(review): assumes module-level ``np`` (numpy) and ``K`` (Keras
        # backend) imports exist elsewhere; neither is visible here - confirm.
        return int(np.sum([K.count_params(p) for p in set(weights)]))

    def print_row(self, fields, positions, print_fn=print):
        """Format ``fields`` into fixed-width columns and emit them as one line.

        Field ``i`` is left-aligned, then truncated or space-padded so that it
        ends exactly at column ``positions[i]``.

        Args:
            fields: values to show; each is converted with ``str``.
            positions: right edge of each column, one entry per field.
            print_fn: callable that receives the finished line
                (defaults to ``print``).
        """
        line = ''
        for i, field in enumerate(fields):
            if i > 0:
                # Force a separating space even when the previous field filled
                # (or was truncated to) its whole column.
                line = line[:-1] + ' '
            line += str(field)
            end = positions[i]
            line = line[:end].ljust(end)
        print_fn(line)

    def print_layer_summary(self, layer, positions, print_fn=print):
        """Emit the summary row for one layer.

        Args:
            layer: a layer with ``name``, ``count_params()`` and optionally
                ``output_shape``.
            positions: column right edges, passed on to ``print_row``.
            print_fn: callable that receives the finished line.
        """
        try:
            output_shape = layer.output_shape
        except AttributeError:
            # Presumably raised when the layer has no single well-defined
            # output shape (e.g. several inbound nodes); show a placeholder.
            output_shape = 'multiple'
        fields = [
            layer.name + ' (' + layer.__class__.__name__ + ')',
            output_shape,
            layer.count_params(),
        ]
        self.print_row(fields, positions, print_fn)

    def _write_summary(self, print_fn):
        """Emit the whole summary table, line by line, through ``print_fn``."""
        line_length = self.line_length
        positions = self.positions
        print_fn('_' * line_length)
        self.print_row(self.to_display, positions, print_fn)
        print_fn('=' * line_length)
        layers = self.model.layers
        last = len(layers) - 1
        for i, layer in enumerate(layers):
            self.print_layer_summary(layer, positions, print_fn)
            # A double rule closes the table; single rules separate layers.
            print_fn(('=' if i == last else '_') * line_length)

    def print_summary(self):
        """Print a summary of ``self.model`` to stdout."""
        self._write_summary(print)

    def print_summary2txt(self):
        """Append a summary of ``self.model`` to the text file ``self.file_path``.

        The file is opened in append mode with UTF-8 encoding, so repeated
        calls accumulate summaries. Nothing is printed to stdout.
        """
        with open(self.file_path, 'a', encoding='utf-8') as f:
            # Route every line into the file rather than stdout.
            self._write_summary(lambda line: print(line, file=f))
'''
下面这段代码可以直接把 `model.summary()` 的输出写入 txt 文件:它用 `contextlib.redirect_stdout` 把标准输出临时重定向到文件中。这是我在 Google 上搜了很久才找到的 pythonic 写法。
'''
import os
from contextlib import redirect_stdout

# Directory for the summary file ('' means the current working directory).
output_file_path = ''
# Name of the text file to create; change as needed.
file_name = 'model_summary.txt'

# os.path.join inserts the path separator that plain `+` concatenation would
# miss when output_file_path has no trailing slash; explicit UTF-8 keeps the
# file independent of the platform's default encoding.
with open(os.path.join(output_file_path, file_name), 'w',
          encoding='utf-8') as f, redirect_stdout(f):
    # While redirected, everything model.summary() prints goes into the file.
    # NOTE(review): assumes `model` is an already-built model defined earlier.
    model.summary()
'''