# CPM-9G-8B/quick_start_clean/data_binarize.py

import os
import re
import sys
import json
import time
import threading
import argparse
from concurrent.futures import ThreadPoolExecutor, wait

from tqdm import tqdm

sys.path.insert(0, '/data/public/CPM-9G/9G-Train')
from cpm.dataset import build_dataset, SimpleDataset

sys.setrecursionlimit(2000)

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, help="raw dataset path", required=True)
    parser.add_argument("--data_type", type=str, help="data type, either json or txt", required=True)
    parser.add_argument("--output_path", type=str, help="output dataset path", required=True)
    parser.add_argument("--output_name", type=str, help="output dataset name", required=True)
    parser.add_argument("--repair_keys", type=str, required=False, default="",
                        help="comma-separated keys whose values are JSON strings to parse back into objects")
    args = parser.parse_args()
    return args

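
# Example invocation (the paths and names below are illustrative, not from
# the original repo):
#
#   python data_binarize.py \
#       --input /path/to/raw \
#       --data_type json \
#       --output_path /path/to/out \
#       --output_name my_dataset \
#       --repair_keys meta,clean_content
#
# As implied by the directory walk in main(), --input may contain data files
# directly, or one level of subdirectories holding the files.
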
DROP_LEN = 100        # drop any piece of text shorter than this
MAX_LENGTH = 10000    # maximum length of a final piece of text
MAX_RECURSION = 1000  # recursion depth cap for split_sent


def split_sent(data_, depth):
    # Recursively split plain text into pieces of at most MAX_LENGTH chars.
    if len(data_) < DROP_LEN:
        return []
    if len(data_) > MAX_LENGTH and depth < MAX_RECURSION:
        if '\n' not in data_:
            return [{"text": data_}]
        # Walk back from the midpoint to the nearest separator; the original
        # comment lists "\n", "。" and similar, so adjust the set as needed.
        mid = len(data_) // 2
        while mid > 0 and data_[mid - 1] not in ("\n", "。"):
            mid -= 1
        ret = []
        ret.extend(split_sent(data_[:mid], depth + 1))
        ret.extend(split_sent(data_[mid:], depth + 1))
        return ret
    else:
        return [{"text": data_}]
TARGET_LANG = {"julia", "visual-basic", "java", "python", "c-sharp", "c", "cpp", "scala", "javascript", "go", "rust",
"ruby", "haskell", "typescript", "lua", "php", "fortran", "r", "sql", "jupyter-scripts-julia",
"jupyter-scripts-java", "jupyter-scripts-python", "jupyter-scripts-csharp", "jupyter-scripts-c++",
"jupyter-scripts-scala", "jupyter-scripts-javascript", "jupyter-scripts-rust", "jupyter-scripts-haskell",
"jupyter-scripts-typescript", "jupyter-scripts-R", "jupiter-structured", "git-commit"}
def is_target_lang(json_obj):
    return json_obj.get("clean_content", {}).get("lang", "") in TARGET_LANG

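
# Note: is_target_lang is defined but never called in this script. To filter
# json records by language, one could (hypothetically) guard the append in
# deal_json_file:
#
#     if is_target_lang(data):
#         data_buffer.append(data)
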

def deal_json_file(file_path, ds_write, repair_keys=None):
    print(f"begin processing {file_path}")
    t0 = time.time()
    with open(file_path, "r", encoding='utf-8') as fin:
        data_buffer = []
        for line in fin:
            line = line.strip()
            data = load_and_repair_json_string(line, repair_keys)
            data_buffer.append(data)
            if len(data_buffer) > 64:
                # Flush the buffer, taking the lock when writing from worker threads.
                if T_LOCK:
                    with T_LOCK:
                        for data in data_buffer:
                            ds_write.write(data)
                else:
                    for data in data_buffer:
                        ds_write.write(data)
                data_buffer = []
        # Flush whatever remains in the buffer.
        if T_LOCK:
            with T_LOCK:
                for data in data_buffer:
                    ds_write.write(data)
        else:
            for data in data_buffer:
                ds_write.write(data)
    print(f"finished {os.path.basename(file_path)}, time spent {time.time() - t0:.2f}s")

def load_and_repair_json_string(line, repair_keys=None):
    data = json.loads(line)
    if repair_keys:
        # Some fields arrive as JSON encoded inside a string; parse them back
        # into objects.
        for key in repair_keys:
            if data[key] is not None and isinstance(data[key], str):
                data[key] = json.loads(data[key])
    return data
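
# Worked example (illustrative record, not from the original repo): given the
# raw line {"id": 1, "meta": "{\"lang\": \"python\"}"} and repair_keys
# ["meta"], the "meta" field is decoded from a JSON string into a dict,
# yielding {"id": 1, "meta": {"lang": "python"}}.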

T_LOCK = None  # set to a threading.Lock in main() when the threaded json path is used


def main():
    args = get_args()

    # Collect input files: files directly under --input, plus one level of
    # subdirectories.
    file_list = []
    for file_i in os.listdir(args.input):
        tmp_dir = os.path.join(args.input, file_i)
        if os.path.isfile(tmp_dir):
            file_list.append(tmp_dir)
        else:
            for file_i_i in os.listdir(tmp_dir):
                file_list.append(os.path.join(tmp_dir, file_i_i))

    repair_keys = args.repair_keys.strip().split(",")
    if len(repair_keys) == 1 and repair_keys[0] == '':
        repair_keys = None

    file_list.sort()
    t0 = time.time()
    with build_dataset(args.output_path, args.output_name) as dataset:
        if args.data_type == "txt":
            for ds in file_list:
                print(ds)
                with open(ds, "r", encoding='utf-8') as fin:
                    for line in tqdm(fin):
                        # Each line is expected to be a JSON-encoded string.
                        line = json.loads(line.strip())
                        # Normalize escaped line breaks; adjust these cleaning
                        # steps to your own data.
                        line = line.strip().replace("\\r\\n", "\n")
                        line = line.strip().replace("\\r", "\n")
                        line = line.strip().replace("<n>", "\n")
                        line = line.strip().replace("\\n", "\n")
                        line = re.sub(r'\n\n[\n]+', '\n\n', line.strip())
                        line = re.sub(r'(\n\s+\n)+', '\n\n', line.strip())
                        # Recursively split the line into bounded-length pieces.
                        line_list = split_sent(line, 1)
                        for item in line_list:
                            dataset.write(item)
        elif args.data_type == "json":
            global T_LOCK
            T_LOCK = threading.Lock()
            # max_workers=1 keeps processing serial; raise it to handle several
            # files in parallel (writes stay guarded by T_LOCK).
            thread_pool = ThreadPoolExecutor(max_workers=1)
            tasks = []
            for ds_path in file_list:
                tasks.append(thread_pool.submit(deal_json_file, ds_path, dataset, repair_keys))
            wait(tasks)
            for task in tasks:
                # result() re-raises any exception captured in a worker thread.
                task.result()
    print(f"total time spent: {time.time() - t0:.2f}s")


if __name__ == "__main__":
    main()
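
# SimpleDataset is imported above but unused here; presumably it reads the
# binarized output back. Its constructor signature is not shown in this file,
# so the sketch below is an assumption to check against 9G-Train's
# cpm/dataset module:
#
#     # ds = SimpleDataset(os.path.join(output_path, output_name))
#     # for sample in ds:
#     #     ...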