@@ -17,48 +17,43 @@ def isCreateOrDeleteFolder(path, flag):
17
17
return flagAbsPath
18
18
19
19
20
- def splitTrainVal (root , absTrainRootPath , absValRootPath , absTestRootPath , trainTxt , valTxt , testTxt , flag ):
21
- # 按照指定的比例划分训练集、验证集、测试集
22
- dataAbsPath = os .path .abspath (root )
23
-
24
- if flag == "det" :
25
- labelFilePath = os .path .join (dataAbsPath , args .detLabelFileName )
26
- elif flag == "rec" :
27
- labelFilePath = os .path .join (dataAbsPath , args .recLabelFileName )
28
-
29
- labelFileRead = open (labelFilePath , "r" , encoding = "UTF-8" )
30
- labelFileContent = labelFileRead .readlines ()
31
- random .shuffle (labelFileContent )
32
- labelRecordLen = len (labelFileContent )
33
-
34
- for index , labelRecordInfo in enumerate (labelFileContent ):
35
- imageRelativePath = labelRecordInfo .split ('\t ' )[0 ]
36
- imageLabel = labelRecordInfo .split ('\t ' )[1 ]
37
- imageName = os .path .basename (imageRelativePath )
38
-
39
- if flag == "det" :
40
- imagePath = os .path .join (dataAbsPath , imageName )
41
- elif flag == "rec" :
42
- imagePath = os .path .join (dataAbsPath , "{}\\ {}" .format (args .recImageDirName , imageName ))
43
-
44
- # 按预设的比例划分训练集、验证集、测试集
45
- trainValTestRatio = args .trainValTestRatio .split (":" )
46
- trainRatio = eval (trainValTestRatio [0 ]) / 10
47
- valRatio = trainRatio + eval (trainValTestRatio [1 ]) / 10
48
- curRatio = index / labelRecordLen
49
-
50
- if curRatio < trainRatio :
51
- imageCopyPath = os .path .join (absTrainRootPath , imageName )
52
- shutil .copy (imagePath , imageCopyPath )
53
- trainTxt .write ("{}\t {}" .format (imageCopyPath , imageLabel ))
54
- elif curRatio >= trainRatio and curRatio < valRatio :
55
- imageCopyPath = os .path .join (absValRootPath , imageName )
56
- shutil .copy (imagePath , imageCopyPath )
57
- valTxt .write ("{}\t {}" .format (imageCopyPath , imageLabel ))
58
- else :
59
- imageCopyPath = os .path .join (absTestRootPath , imageName )
60
- shutil .copy (imagePath , imageCopyPath )
61
- testTxt .write ("{}\t {}" .format (imageCopyPath , imageLabel ))
20
+ def splitTrainVal (root , abs_train_root_path , abs_val_root_path , abs_test_root_path , train_txt , val_txt , test_txt , flag ):
21
+
22
+ data_abs_path = os .path .abspath (root )
23
+ label_file_name = args .detLabelFileName if flag == "det" else args .recLabelFileName
24
+ label_file_path = os .path .join (data_abs_path , label_file_name )
25
+
26
+ with open (label_file_path , "r" , encoding = "UTF-8" ) as label_file :
27
+ label_file_content = label_file .readlines ()
28
+ random .shuffle (label_file_content )
29
+ label_record_len = len (label_file_content )
30
+
31
+ for index , label_record_info in enumerate (label_file_content ):
32
+ image_relative_path , image_label = label_record_info .split ('\t ' )
33
+ image_name = os .path .basename (image_relative_path )
34
+
35
+ if flag == "det" :
36
+ image_path = os .path .join (data_abs_path , image_name )
37
+ elif flag == "rec" :
38
+ image_path = os .path .join (data_abs_path , args .recImageDirName , image_name )
39
+
40
+ train_val_test_ratio = args .trainValTestRatio .split (":" )
41
+ train_ratio = eval (train_val_test_ratio [0 ]) / 10
42
+ val_ratio = train_ratio + eval (train_val_test_ratio [1 ]) / 10
43
+ cur_ratio = index / label_record_len
44
+
45
+ if cur_ratio < train_ratio :
46
+ image_copy_path = os .path .join (abs_train_root_path , image_name )
47
+ shutil .copy (image_path , image_copy_path )
48
+ train_txt .write ("{}\t {}\n " .format (image_copy_path , image_label ))
49
+ elif cur_ratio >= train_ratio and cur_ratio < val_ratio :
50
+ image_copy_path = os .path .join (abs_val_root_path , image_name )
51
+ shutil .copy (image_path , image_copy_path )
52
+ val_txt .write ("{}\t {}\n " .format (image_copy_path , image_label ))
53
+ else :
54
+ image_copy_path = os .path .join (abs_test_root_path , image_name )
55
+ shutil .copy (image_path , image_copy_path )
56
+ test_txt .write ("{}\t {}\n " .format (image_copy_path , image_label ))
62
57
63
58
64
59
# 删掉存在的文件
@@ -148,4 +143,4 @@ def genDetRecTrainVal(args):
148
143
help = "the name of the folder where the cropped recognition dataset is located"
149
144
)
150
145
args = parser .parse_args ()
151
- genDetRecTrainVal (args )
146
+ genDetRecTrainVal (args )
0 commit comments