CPM-9G-8B/quick_start_clean/data_binarize.py

164 lines
5.8 KiB
Python
Raw Permalink Normal View History

2024-02-28 17:42:58 +08:00
import concurrent
import os
import threading
import argparse
from concurrent.futures import ThreadPoolExecutor, wait
import time
import json
import random
import sys
import re
from tqdm import tqdm
sys.path.insert(0, '/data/public/CPM-9G/9G-Train')
from cpm.dataset import build_dataset, SimpleDataset
sys.setrecursionlimit(2000)
def get_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: carries ``input``, ``data_type``, ``output_path``
        and ``output_name`` (all required strings) plus ``repair_keys``
        (optional string, ``""`` when not given).
    """
    # (flag, help text) for every mandatory string option.
    required_options = (
        ("--input", "raw dataset path"),
        ("--data_type", "datatype can be json or txt"),
        ("--output_path", "output dataset path"),
        ("--output_name", "output dataset name"),
    )

    cli = argparse.ArgumentParser()
    for flag, description in required_options:
        cli.add_argument(flag, type=str, help=description, required=True)
    cli.add_argument(
        "--repair_keys",
        type=str,
        help="json string to json",
        required=False,
        default="",
    )
    return cli.parse_args()
DROP_LEN = 100  # a piece of text shorter than this many characters is dropped
MAX_LENGTH = 10000  # a piece of text longer than this is split further
MAX_RECURSION = 1000  # recursion depth at which splitting stops


def split_sent(data_, depth):
    """Recursively split plain text at newlines into pieces of bounded length.

    Args:
        data_: the text to split.
        depth: current recursion depth; callers start at 1.

    Returns:
        A list of ``{"text": piece}`` dicts in original order. Pieces shorter
        than ``DROP_LEN`` are dropped. A piece may still exceed ``MAX_LENGTH``
        when it contains no usable newline or ``depth`` reached
        ``MAX_RECURSION``.
    """
    if len(data_) < DROP_LEN:
        return []
    if len(data_) <= MAX_LENGTH or depth >= MAX_RECURSION or "\n" not in data_:
        return [{"text": data_}]

    # Cut right after the newline closest to (and before) the midpoint.
    # "\n" is the only separator; extend the search (e.g. with the CJK full
    # stop) if your data needs other delimiters.
    half = len(data_) // 2
    cut = data_.rfind("\n", 0, half) + 1
    if cut == 0:
        # No newline in the first half. Cutting at 0 would recurse on the
        # very same string without progress, so use the first newline at or
        # after the midpoint instead (one exists because "\n" is in data_).
        cut = data_.find("\n", half) + 1
    if cut >= len(data_):
        # The only newline is the last character: nothing left to split off.
        return [{"text": data_}]

    return split_sent(data_[:cut], depth + 1) + split_sent(data_[cut:], depth + 1)
# import orjson
import time
import fcntl
# Language labels accepted by is_target_lang().
TARGET_LANG = {
    "julia",
    "visual-basic",
    "java",
    "python",
    "c-sharp",
    "c",
    "cpp",
    "scala",
    "javascript",
    "go",
    "rust",
    "ruby",
    "haskell",
    "typescript",
    "lua",
    "php",
    "fortran",
    "r",
    "sql",
    "jupyter-scripts-julia",
    "jupyter-scripts-java",
    "jupyter-scripts-python",
    "jupyter-scripts-csharp",
    "jupyter-scripts-c++",
    "jupyter-scripts-scala",
    "jupyter-scripts-javascript",
    "jupyter-scripts-rust",
    "jupyter-scripts-haskell",
    "jupyter-scripts-typescript",
    "jupyter-scripts-R",
    "jupiter-structured",
    "git-commit",
}


def is_target_lang(json_obj):
    """Tell whether a record's ``clean_content.lang`` value is in TARGET_LANG.

    Args:
        json_obj: mapping that may hold a nested ``"clean_content"`` mapping
            with a ``"lang"`` entry; missing entries count as no language.

    Returns:
        bool: True when the language label is one of ``TARGET_LANG``.
    """
    cleaned = json_obj.get("clean_content", {})
    language = cleaned.get("lang", "")
    return language in TARGET_LANG
# Lock guarding writes to the shared dataset writer. It stays None until
# main() installs a threading.Lock for the threaded "json" mode.
T_LOCK = None

# A buffered batch is flushed once it holds more than this many records.
_WRITE_BATCH_SIZE = 64


def _write_records(ds_write, records):
    """Write every record to *ds_write*, holding T_LOCK when one is installed."""
    lock = T_LOCK
    if not lock:
        for record in records:
            ds_write.write(record)
        return
    # `with` releases the lock even when a write raises, so an error in one
    # worker cannot leave the others blocked forever.
    with lock:
        for record in records:
            ds_write.write(record)


def deal_json_file(file_path, ds_write, repair_keys=None):
    """Read a JSON-lines file and write each parsed record to *ds_write*.

    Records are buffered and written in batches. Blank lines are skipped.

    Args:
        file_path: path of a UTF-8 text file with one JSON document per line.
        ds_write: object exposing ``write(record)``.
        repair_keys: optional keys whose string values are themselves JSON
            and get decoded (see ``load_and_repair_json_string``).

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON.
        KeyError: a record lacks one of ``repair_keys``.
    """
    print(f"begin deal {file_path}")
    t0 = time.time()
    data_buffer = []
    with open(file_path, "r", encoding='utf-8') as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            data_buffer.append(load_and_repair_json_string(line, repair_keys))
            if len(data_buffer) > _WRITE_BATCH_SIZE:
                _write_records(ds_write, data_buffer)
                data_buffer = []
    # Flush the last partial batch whether or not a lock is installed;
    # otherwise the trailing records of the file would be lost.
    if data_buffer:
        _write_records(ds_write, data_buffer)
    print(f"deal {os.path.basename(file_path)} time spend {time.time() - t0}")


def load_and_repair_json_string(line, repair_keys=None):
    """Parse one JSON line and decode string fields that hold nested JSON.

    Args:
        line: a JSON document as text.
        repair_keys: optional iterable of keys; for each one, a ``str`` value
            is replaced by the result of ``json.loads`` on it. Non-string
            values (including ``None``) are left unchanged.

    Returns:
        The parsed, possibly repaired, object.

    Raises:
        json.JSONDecodeError: *line* or a repaired value is not valid JSON.
        KeyError: a key in *repair_keys* is missing from the record.
    """
    data = json.loads(line)
    for key in repair_keys or ():
        value = data[key]
        if isinstance(value, str):
            data[key] = json.loads(value)
    return data
# Raw strings keep the regex escapes (\n, \s) out of Python's own
# string-escape processing.
_EXTRA_BLANK_LINES = re.compile(r"\n\n[\n]+")  # three or more newlines in a row
_WHITESPACE_ONLY_LINES = re.compile(r"(\n\s+\n)+")  # lines holding only whitespace


def _collect_input_files(input_dir):
    """Return, sorted, the files in *input_dir* plus the entries of its sub-directories."""
    paths = []
    for entry in os.listdir(input_dir):
        entry_path = os.path.join(input_dir, entry)
        if os.path.isfile(entry_path):
            paths.append(entry_path)
        else:
            # Only one directory level is expanded.
            paths.extend(os.path.join(entry_path, child) for child in os.listdir(entry_path))
    paths.sort()
    return paths


def _parse_repair_keys(raw):
    """Turn the comma-separated --repair_keys value into a list of keys, or None if empty."""
    keys = [key.strip() for key in raw.split(",")]
    keys = [key for key in keys if key]
    return keys or None


def _clean_text(text):
    """Normalise line breaks in *text*; adjust these cleaning steps to your data."""
    # Literal escape sequences and the "<n>" marker all become real newlines.
    for marker in ("\\r\\n", "\\r", "<n>", "\\n"):
        text = text.strip().replace(marker, "\n")
    text = _EXTRA_BLANK_LINES.sub("\n\n", text.strip())
    return _WHITESPACE_ONLY_LINES.sub("\n\n", text.strip())


def main():
    """Convert every input file into the output dataset.

    ``--data_type txt``: each non-blank input line is a JSON-encoded string;
    it is cleaned, split with ``split_sent`` and every piece is written.
    ``--data_type json``: each file is handled by ``deal_json_file`` on a
    single worker thread.

    Raises:
        ValueError: ``--data_type`` is neither ``txt`` nor ``json``.
    """
    global T_LOCK
    args = get_args()
    if args.data_type not in ("txt", "json"):
        # Fail before creating any output instead of silently writing an
        # empty dataset.
        raise ValueError(f"unsupported data_type {args.data_type!r}; expected 'json' or 'txt'")

    file_list = _collect_input_files(args.input)
    repair_keys = _parse_repair_keys(args.repair_keys)
    t0 = time.time()
    with build_dataset(args.output_path, args.output_name) as dataset:
        if args.data_type == "txt":
            for ds in file_list:
                print(ds)
                with open(ds, "r", encoding='utf-8') as fin:
                    for line in tqdm(fin):
                        line = line.strip()
                        if not line:
                            continue  # a blank line carries no record
                        text = _clean_text(json.loads(line))
                        # Recursively split the text into bounded pieces.
                        for item in split_sent(text, 1):
                            dataset.write(item)
        else:
            T_LOCK = threading.Lock()
            # The context manager shuts the pool down once all files are done.
            with ThreadPoolExecutor(max_workers=1) as thread_pool:
                tasks = [
                    thread_pool.submit(deal_json_file, ds_path, dataset, repair_keys)
                    for ds_path in file_list
                ]
                wait(tasks)
            for task in tasks:
                task.result()  # re-raises any exception from the worker
    print(f"all time spend:{time.time() - t0}")


if __name__ == "__main__":
    main()