| import argparse |
| import os |
|
|
|
|
def _read_lines(path):
    """Return the lines of the UTF-8 text file at *path*, each stripped of surrounding whitespace."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def main(cli_args=None):
    """Filter each target file down to the lines labelled with its document id.

    For every file in ``tgt_path`` ending in ``.<tgt_lang>``:

    * the source stem is the file name without that extension;
    * the document id is the last ``.``-separated component of the stem;
    * the label file (in ``disturb_src_path``) is the stem with
      ``.<src_lang>.`` replaced by ``.id.`` and holds one label per target line.

    A target line is kept when the part of its label before the first ``-``
    equals the document id. The number of kept lines must equal the number of
    lines in the original source file (``original_src_path``/stem). The kept
    lines are written to ``output_path`` under the target file's name.

    Args:
        cli_args: Namespace with ``original_src_path``, ``disturb_src_path``,
            ``tgt_path``, ``output_path`` and ``lang_pair`` (``"<src>-<tgt>"``).
            Defaults to the module-level ``args`` parsed in the ``__main__``
            block, so the original zero-argument call keeps working.

    Raises:
        ValueError: If a target file and its label file differ in length, or
            the filtered target does not match the original source line count.
    """
    if cli_args is None:
        cli_args = args  # module-level namespace populated under __main__

    src_lang, tgt_lang = cli_args.lang_pair.split("-")
    tgt_suffix = f".{tgt_lang}"
    # os.listdir order is arbitrary; sort so processing/error order is reproducible.
    tgt_file_list = sorted(
        name for name in os.listdir(cli_args.tgt_path) if name.endswith(tgt_suffix)
    )

    for tgt_file in tgt_file_list:
        src_file = os.path.splitext(tgt_file)[0]
        doc_id = src_file.split(".")[-1]
        label_file = src_file.replace(f".{src_lang}.", ".id.")

        labels = _read_lines(os.path.join(cli_args.disturb_src_path, label_file))
        tgt_lines = _read_lines(os.path.join(cli_args.tgt_path, tgt_file))
        # Explicit raise instead of assert: asserts disappear under `python -O`,
        # which would let misaligned output be written silently.
        if len(labels) != len(tgt_lines):
            raise ValueError(f"Length mismatch in {tgt_file} and {label_file}")

        filtered_tgt_lines = [
            tgt
            for tgt, label in zip(tgt_lines, labels)
            if label.split("-")[0] == doc_id
        ]

        original_src_lines = _read_lines(
            os.path.join(cli_args.original_src_path, src_file)
        )
        if len(original_src_lines) != len(filtered_tgt_lines):
            raise ValueError(
                f"Length mismatch in {src_file} and filtered {tgt_file}"
            )

        with open(
            os.path.join(cli_args.output_path, tgt_file), "w", encoding="utf-8"
        ) as f:
            # One newline per kept line; an empty result yields an empty file
            # rather than a single spurious blank line.
            f.writelines(line + "\n" for line in filtered_tgt_lines)
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Keep only the target lines whose label matches their document id."
    )
    # Every option is mandatory: omitting one previously surfaced later as an
    # opaque TypeError/AttributeError (e.g. os.makedirs(None)) instead of a
    # proper usage message.
    parser.add_argument(
        "--original_src_path", type=str, required=True,
        help="directory holding the original source files",
    )
    parser.add_argument(
        "--disturb_src_path", type=str, required=True,
        help="directory holding the '.id.' label files",
    )
    parser.add_argument(
        "--tgt_path", type=str, required=True,
        help="directory holding the target files to filter",
    )
    parser.add_argument(
        "--output_path", type=str, required=True,
        help="directory the filtered target files are written to (created if missing)",
    )
    parser.add_argument(
        "--lang_pair", type=str, required=True,
        help='language pair as "<src>-<tgt>", e.g. "en-de"',
    )
    # Kept as the module-level name `args`: main() falls back to it when called
    # without arguments.
    args = parser.parse_args()

    os.makedirs(args.output_path, exist_ok=True)

    main()
|
|