关于 TFRecord

匿名 (未验证) 提交于 2019-12-02 23:59:01

写入 TFRecord

        # Serialize the training set into a TFRecord file (TensorFlow 1.x API).
        # Each tf.train.Example stores the per-sentence numpy feature arrays as
        # raw bytes, plus scalar int64 fields for sequence length, label and
        # task id (0 == this task).
        print("convert data into tfrecord:train\n")
        out_file_train = "/home/huadong.wang/bo.yan/fudan_mtl/data/ace2005/bn_nw.train.tfrecord"
        writer = tf.python_io.TFRecordWriter(out_file_train)
        for i in tqdm(range(len(data_train))):
            record = tf.train.Example(features=tf.train.Features(feature={
                'word_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_x[i].tostring()])),
                'et_ids1': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_et1[i].tostring()])),
                'et_ids2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_et2[i].tostring()])),
                'position_ids1': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_p1[i].tostring()])),
                # BUG FIX: the original serialized train_p1 here as well, so
                # position_ids2 duplicated entity 1's positions and entity 2's
                # positions were never written. Uses train_p2 by symmetry with
                # train_et1/train_et2 -- TODO confirm train_p2 is the second
                # entity's position array in the surrounding script.
                'position_ids2': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_p2[i].tostring()])),
                'chunks': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_chunks[i].tostring()])),
                'spath_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[train_spath[i].tostring()])),
                'seq_len': tf.train.Feature(int64_list=tf.train.Int64List(value=[train_x_len[i]])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[np.argmax(train_relation[i])])),
                'task': tf.train.Feature(int64_list=tf.train.Int64List(value=[np.int64(0)]))
            }))
            writer.write(record.SerializeToString())
        writer.close()

  

解析 TFRecord

def _parse_tfexample(serialized_example):
  """Parse one serialized tf.train.SequenceExample into tensors.

  Context features (scalar int64): label, task, seq_len.
  Sequence features (int64 per token): the sentence id sequences.

  Returns:
    (task, label, sentence) where sentence is a tuple of the seven
    sequence tensors followed by seq_len.
  """
  # NOTE(review): the writer snippet above serializes these fields as raw
  # bytes inside a plain tf.train.Example; parsing them here as int64
  # sequence features of a SequenceExample only works if the records were
  # actually written in SequenceExample form -- confirm against the writer.
  context_features = {
      'label': tf.FixedLenFeature([], tf.int64),
      'task': tf.FixedLenFeature([], tf.int64),
      'seq_len': tf.FixedLenFeature([], tf.int64),
  }
  # All sequence features share the same spec, so build the dict in one pass.
  sequence_names = ('word_ids', 'et_ids1', 'et_ids2', 'position_ids1',
                    'position_ids2', 'chunks', 'spath_ids')
  sequence_features = {
      name: tf.FixedLenSequenceFeature([], tf.int64) for name in sequence_names
  }
  context_dict, sequence_dict = tf.parse_single_sequence_example(
      serialized_example,
      context_features=context_features,
      sequence_features=sequence_features)

  sentence = (sequence_dict['word_ids'], sequence_dict['et_ids1'],
              sequence_dict['et_ids2'], sequence_dict['position_ids1'],
              sequence_dict['position_ids2'], sequence_dict['chunks'],
              sequence_dict['spath_ids'], context_dict['seq_len'])
  return context_dict['task'], context_dict['label'], sentence


def read_tfrecord(epoch, batch_size):
  """Yield a (train_data, test_data) pair for every dataset in DATASETS.

  Train records are shuffled; test records are read in order.
  """
  for dataset in DATASETS:
    train_record_file = os.path.join(OUT_DIR, dataset + '.train.tfrecord')
    test_record_file = os.path.join(OUT_DIR, dataset + '.test.tfrecord')

    train_data = util.read_tfrecord(train_record_file,
                                    epoch,
                                    batch_size,
                                    _parse_tfexample,
                                    shuffle=True)
    test_data = util.read_tfrecord(test_record_file,
                                   epoch,
                                   batch_size,
                                   _parse_tfexample,
                                   shuffle=False)
    yield train_data, test_data

模型中使用:

  def build_task_graph(self, data):
    """Unpack one parsed batch and stash its feature tensors on self.

    `data` is the (task_label, labels, sentence) triple produced by
    _parse_tfexample; `sentence` holds the seven id sequences plus seq_len.
    """
    task_label, labels, sentence = data
    (word_ids, et_ids1, et_ids2, position_ids1,
     position_ids2, chunks, spath_ids, seq_len) = sentence

    # Expose every component as an attribute for the embedding layers below.
    self.word_ids = word_ids
    self.position_ids1 = position_ids1
    self.position_ids2 = position_ids2
    self.et_ids1 = et_ids1
    self.et_ids2 = et_ids2
    self.chunks_ids = chunks
    self.spath_ids = spath_ids
    self.seq_len = seq_len

    sentence = self.add_embedding_layers()

  

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!