-
Notifications
You must be signed in to change notification settings - Fork 265
/
data_util.py
170 lines (140 loc) · 6.1 KB
/
data_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import os
import os.path as osp
import shutil
import time
import datetime
import torch
from util.slconfig import SLConfig
class Error(OSError):
    """Aggregate error raised by ``slcopytree`` after a tree copy.

    ``args[0]`` is a list of ``(src, dst, reason)`` tuples, one for each
    entry that could not be copied.
    """
def slcopytree(src, dst, symlinks=False, ignore=None, copy_function=shutil.copyfile,
               ignore_dangling_symlinks=False):
    """Recursively copy ``src`` to ``dst`` without copying metadata.

    Modified from ``shutil.copytree``: ``copystat`` is never called, so
    directory permissions and timestamps are not propagated.

    If ``src`` is a directory, ``dst`` must not already exist (it is created
    with ``os.makedirs``) and every entry of ``src`` is copied into it. If
    ``src`` is not a directory, it is copied directly with ``copy_function``.

    Args:
        src: Source path (directory or file).
        dst: Destination path.
        symlinks: If true, symbolic links in the source tree are recreated as
            symbolic links (with the same stored target) in the destination.
            If false, the contents the links point to are copied.
        ignore: Optional callable ``ignore(dir, names) -> ignored_names``. It
            is called once per visited directory with that directory's path
            and its ``os.listdir`` entries; returned names are not copied.
        copy_function: Callable ``(src_path, dst_path)`` used to copy each
            file. Defaults to ``shutil.copyfile`` (file contents only).
        ignore_dangling_symlinks: If true and ``symlinks`` is false, links
            whose target does not exist are silently skipped instead of
            being reported as errors. Applies at every depth of the tree.

    Returns:
        ``dst``.

    Raises:
        Error: After the whole tree has been walked, if any entry failed to
            copy. ``args[0]`` is the list of ``(srcname, dstname, reason)``
            tuples, including those collected from nested directories.
        OSError: If ``src`` cannot be listed or ``dst`` cannot be created.
    """
    if not os.path.isdir(src):
        copy_function(src, dst)
        return dst

    names = os.listdir(src)
    ignored_names = ignore(src, names) if ignore is not None else set()
    os.makedirs(dst)

    errors = []
    for name in names:
        if name in ignored_names:
            continue
        srcname = os.path.join(src, name)
        dstname = os.path.join(dst, name)
        try:
            if os.path.islink(srcname):
                if symlinks:
                    # Recreate the link itself. Leaving this to a custom
                    # copy_function would follow the link instead.
                    os.symlink(os.readlink(srcname), dstname)
                    continue
                # os.path.exists follows the link relative to the link's own
                # directory, so relative targets are resolved correctly
                # (checking the raw readlink() string would resolve it
                # against the current working directory).
                if ignore_dangling_symlinks and not os.path.exists(srcname):
                    continue
                # Otherwise the link target is copied like a regular entry;
                # a dangling target makes copy_function raise an OSError.
            if os.path.isdir(srcname):
                slcopytree(srcname, dstname, symlinks, ignore, copy_function,
                           ignore_dangling_symlinks)
            else:
                # Unsupported file types surface as OSError from copy_function.
                copy_function(srcname, dstname)
        # Collect errors (including those from recursive calls) so the
        # remaining entries are still copied.
        except Error as err:
            errors.extend(err.args[0])
        except OSError as why:
            errors.append((srcname, dstname, str(why)))

    if errors:
        raise Error(errors)
    return dst
def check_and_copy(src_path, tgt_path):
    """Copy ``src_path`` to ``tgt_path`` with ``slcopytree`` unless the target exists.

    Returns:
        ``None`` if ``tgt_path`` already exists (nothing is copied);
        otherwise the value returned by ``slcopytree``.
    """
    target_missing = not os.path.exists(tgt_path)
    if target_missing:
        return slcopytree(src_path, tgt_path)
    return None
def remove(srcpath):
    """Delete ``srcpath`` from the filesystem.

    Real directories are deleted recursively with ``shutil.rmtree``.
    Everything else is deleted with ``os.remove``. That includes regular
    files and symbolic links, even links that point at a directory, so only
    the link itself is removed and the directory it points to is left intact.

    Returns:
        ``None``.

    Raises:
        OSError: (e.g. ``FileNotFoundError``) if the path cannot be removed.
    """
    # os.path.isdir follows symlinks, and shutil.rmtree refuses to operate on
    # a symlink (it raises OSError), so links must go through os.remove.
    if os.path.isdir(srcpath) and not os.path.islink(srcpath):
        shutil.rmtree(srcpath)
    else:
        os.remove(srcpath)
def preparing_dataset(pathdict, image_set, args):
    """Stage dataset files from their static source locations to target paths.

    For every ``key -> target`` entry in ``pathdict``, the source path is read
    from ``util/static_data_path.py`` at ``[args.dataset_file][image_set][key]``
    and copied to the target. An existing target is deleted first. Sources
    ending in ``.zip`` are copied next to the target and extracted into the
    target's parent directory.

    Only the process with ``args.local_rank == 0`` deletes or copies anything.
    When ``args.distributed`` is set, every process then waits on
    ``torch.distributed.barrier()``.

    Args:
        pathdict: Mapping of config key to target path.
        image_set: Split name used to index the static data config.
        args: Namespace providing ``dataset_file``, ``local_rank`` and
            ``distributed``. ``args.copyfilelist`` is set as a side effect.

    Returns:
        The list of paths created by local rank 0 (copied zip archives and
        copied/extracted targets), or ``None`` if nothing was copied. On
        other local ranks this is always ``None``.
    """
    import zipfile

    start_time = time.time()
    dataset_file = args.dataset_file
    data_static_info = SLConfig.fromfile('util/static_data_path.py')
    static_dict = data_static_info[dataset_file][image_set]

    copyfilelist = []
    if args.local_rank == 0:
        for k, tgt_v in pathdict.items():
            # Only local rank 0 deletes the stale target, because it is the
            # only process that re-creates it. When every rank deleted it, a
            # slower rank could wipe out data rank 0 had just copied, or fail
            # with FileNotFoundError racing on the same removal. lexists also
            # catches a dangling symlink left at the target path.
            if os.path.lexists(tgt_v):
                print("path <{}> exist. remove it!".format(tgt_v))
                remove(tgt_v)
            src_v = static_dict[k]
            assert isinstance(src_v, str)
            if src_v.endswith('.zip'):
                # Copy the archive next to the target, then extract it there.
                cp_tgt_dir = os.path.dirname(tgt_v)
                filename = os.path.basename(src_v)
                cp_tgt_path = os.path.join(cp_tgt_dir, filename)
                print('Copy from <{}> to <{}>.'.format(src_v, cp_tgt_path))
                # os.makedirs('') raises, so use '.' for a bare filename.
                os.makedirs(cp_tgt_dir or '.', exist_ok=True)
                check_and_copy(src_v, cp_tgt_path)
                print("Starting unzip <{}>".format(cp_tgt_path))
                with zipfile.ZipFile(cp_tgt_path, 'r') as zip_ref:
                    zip_ref.extractall(cp_tgt_dir)
                copyfilelist.append(cp_tgt_path)
                # NOTE(review): assumes the archive extracts to tgt_v — confirm.
                copyfilelist.append(tgt_v)
            else:
                print('Copy from <{}> to <{}>.'.format(src_v, tgt_v))
                os.makedirs(os.path.dirname(tgt_v) or '.', exist_ok=True)
                check_and_copy(src_v, tgt_v)
                copyfilelist.append(tgt_v)

    if not copyfilelist:
        copyfilelist = None
    args.copyfilelist = copyfilelist
    if args.distributed:
        # Hold every process here so none proceeds before rank 0 has finished
        # staging the data.
        torch.distributed.barrier()
    total_time = time.time() - start_time
    if copyfilelist:
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Data copy time {}'.format(total_time_str))
    return copyfilelist