批量处理NTU数据集

时光总嘲笑我的痴心妄想 提交于 2020-03-06 00:07:03
'''
读取NTU原始数据的x y z 坐标,
保存在字典中
并将文件保存
'''
import os
import numpy as np
import pickle


def get_raw_body_data(ske_file, ske_name, *, verbose=True):
    """Parse one NTU-style ``.skeleton`` text file into per-body joint arrays.

    The file is read as follows:
      * line 0: number of frames;
      * per frame: a body-count line, then for each body a line whose first
        whitespace-separated token is the body ID, a joint-count line, and one
        line per joint whose first three tokens are the x, y, z coordinates.

    Args:
        ske_file: Path to the skeleton text file.
        ske_name: Stored verbatim under the ``'name'`` key of the result.
        verbose: When true (the default, matching the previous behavior),
            print progress messages, including one line per frame.

    Returns:
        ``{'name': ske_name, 'data': bodies_data}`` where ``bodies_data`` maps
        each body-ID string (in order of first appearance) to
        ``{'joints': array}``. The array is float32 with shape ``(25 * n, 3)``:
        the ``(25, 3)`` joint block of each of the ``n`` frames in which that
        body appears, stacked row-wise in frame order. Rows beyond a frame's
        joint count are left as zeros.

    Raises:
        AssertionError: If ``ske_file`` does not exist. Raised explicitly
            (not via ``assert``) so the check is not stripped by ``python -O``
            while callers still see the same exception type as before.
    """
    if not os.path.exists(ske_file):
        raise AssertionError('Error:Skeleton file {} not found'.format(ske_file))
    if verbose:
        print("Reading data from {}".format(ske_file))
    with open(ske_file, 'r') as fr:
        str_data = fr.readlines()
    num_frames = int(str_data[0].strip('\r\n'))
    if verbose:
        print("此文件共:{}帧".format(num_frames))

    # Collect each body's per-frame (25, 3) blocks in a list and stack once at
    # the end: calling np.vstack on every frame re-copies the growing array,
    # which makes parsing quadratic in the number of frames.
    joint_blocks = {}
    current_line = 1
    for f in range(num_frames):
        num_bodies = int(str_data[current_line].strip('\r\n'))
        current_line += 1
        if verbose:
            print("正在读取第{}帧".format(f))
        if num_bodies == 0:
            # Frame with no tracked body: nothing to record.
            continue

        for _ in range(num_bodies):
            body_id = str_data[current_line].split()[0]
            current_line += 1
            num_joints = int(str_data[current_line].strip('\r\n'))
            current_line += 1

            # Fixed 25 joint rows per frame; unused rows stay zero.
            joints = np.zeros((25, 3), dtype=np.float32)
            for j in range(num_joints):
                fields = str_data[current_line].split()
                joints[j] = np.array(fields[:3], dtype=np.float32)
                current_line += 1

            joint_blocks.setdefault(body_id, []).append(joints)

    bodies_data = {
        body_id: {'joints': np.vstack(blocks)}
        for body_id, blocks in joint_blocks.items()
    }
    return {'name': ske_name, 'data': bodies_data}

if __name__ == '__main__':
    # Convert every skeleton file in skes_path into a pickled dict written
    # under the same filename in skes_new_path.
    skes_path = './/skeleton'
    skes_new_path = './/skeleton_new'
    # The output directory is not guaranteed to exist; without this the first
    # open(path2, 'wb') below raises FileNotFoundError.
    os.makedirs(skes_new_path, exist_ok=True)

    # Sorted so runs process files in a reproducible order
    # (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(skes_path)):
        path1 = os.path.join(skes_path, filename)
        if not os.path.isfile(path1):
            # Skip sub-directories and other non-file entries.
            continue
        path2 = os.path.join(skes_new_path, filename)
        print("正在处理{}文件".format(filename))
        raw_skes_data = get_raw_body_data(path1, filename)

        # Output keeps the input filename but holds a pickled dict of the form
        # {'name': filename, 'data': {body_id: {'joints': ndarray}}}.
        with open(path2, 'wb') as fw:
            pickle.dump(raw_skes_data, fw, pickle.HIGHEST_PROTOCOL)

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!