| |
| |
|
|
| """ |
| Build training graphs from KWDLC with JUMANDIC. |
| |
| Pipeline: |
| 1) Read gold morphemes from KNP files |
| 2) Parse text with MeCab (JUMANDIC) to get candidate morphemes |
| 3) Match candidates to gold and assign annotations ('+', '-', '?') |
| 4) Save graph data as .pt |
| """ |
|
|
| import argparse |
| from collections import defaultdict |
| from pathlib import Path |
| from typing import Dict, List |
|
|
| import torch |
| import yaml |
| from tqdm import tqdm |
|
|
| from mecari.analyzers.mecab import MeCabAnalyzer |
| from mecari.data.data_module import DataModule |
| from mecari.featurizers.lexical import LexicalNGramFeaturizer as LexicalFeaturizer |
| from mecari.featurizers.lexical import Morpheme |
| from mecari.utils.morph_utils import build_adjacent_edges, dedup_morphemes, normalize_mecab_candidates |
|
|
|
|
def add_lexical_features(morphemes: List[Dict], text: str, feature_dim: int = 100000) -> List[Dict]:
    """Attach lexical (index, value) feature pairs to each morpheme dict.

    Not used when saving JSON; kept for backward-compatibility and test
    equivalence. Mutates the dicts in place and returns the same list.
    """
    featurizer = LexicalFeaturizer(dim=feature_dim, add_bias=True)
    text_len = len(text)
    for entry in morphemes:
        surface = entry.get("surface", "")
        start = entry.get("start_pos", 0)
        end = entry.get("end_pos", start + len(surface))
        # Characters immediately adjacent to the span; None at sentence edges.
        left_char = text[start - 1] if 0 < start <= text_len else None
        right_char = text[end] if end < text_len else None
        morph = Morpheme(
            surf=surface,
            lemma=entry.get("base_form", surface),
            pos=entry.get("pos", "*"),
            pos1=entry.get("pos_detail1", "*"),
            ctype="*",
            cform="*",
            reading=entry.get("reading", "*"),
        )
        entry["lexical_features"] = featurizer.unigram_feats(morph, left_char, right_char)
    return morphemes
|
|
|
|
def hiragana_to_katakana(text: str) -> str:
    """Convert hiragana characters in *text* to katakana; pass others through.

    The katakana block sits exactly 0x60 code points above hiragana, so a
    fixed offset maps each character. The range ぁ..ゖ (U+3041-U+3096) also
    covers ゔ/ゕ/ゖ, which the previous ぁ..ん upper bound missed even though
    they have katakana counterparts (ヴ/ヵ/ヶ).
    """
    offset = 0x60  # ord("ア") - ord("あ")
    return "".join(chr(ord(c) + offset) if "ぁ" <= c <= "ゖ" else c for c in text)
|
|
|
|
def _load_gold_with_kyoto(knp_path: Path) -> List[Dict]:
    """Load sentences and gold morphemes from a KNP file using kyoto-reader (required).

    Returns a list of {"text": str, "morphemes": [dict, ...]} entries;
    morpheme spans are derived from cumulative surface lengths. Raises
    RuntimeError when kyoto-reader is missing or parsing fails.
    """
    try:
        from kyoto_reader import KyotoReader
    except Exception as e:
        raise RuntimeError("kyoto-reader is required for gold loading. Install it (pip install kyoto-reader).") from e

    try:
        # Newer kyoto-reader versions accept n_jobs; fall back for older ones.
        try:
            reader = KyotoReader(str(knp_path), n_jobs=0)
        except TypeError:
            reader = KyotoReader(str(knp_path))

        sentences: List[Dict] = []
        for doc in reader.process_all_documents(n_jobs=0):
            if doc is None:
                continue
            for sent in doc.sentences:
                offset = 0
                mrphs: List[Dict] = []
                for mrph in sent.mrph_list():
                    surface = getattr(mrph, "midasi", "") or ""
                    start, end = offset, offset + len(surface)
                    offset = end
                    mrphs.append(
                        {
                            "surface": surface,
                            "reading": getattr(mrph, "yomi", surface) or surface,
                            "base_form": getattr(mrph, "genkei", surface) or surface,
                            "pos": getattr(mrph, "hinsi", "*") or "*",
                            "pos_detail1": getattr(mrph, "bunrui", "*") or "*",
                            "pos_detail2": "*",
                            "pos_detail3": "*",
                            "start_pos": start,
                            "end_pos": end,
                        }
                    )
                sentences.append({"text": sent.surf, "morphemes": mrphs})
        return sentences
    except Exception as e:
        raise RuntimeError(f"Failed to parse KNP with kyoto-reader: {knp_path}") from e
|
|
|
|
def _span_of(morpheme: Dict) -> tuple[int, int]:
    """Return the (start, end) character span of a candidate morpheme dict."""
    start = morpheme.get("start_pos", 0)
    end = morpheme.get("end_pos", start + len(morpheme.get("surface", "")))
    return start, end


def _gold_details_with_spans(gold_morphemes: List[Dict]) -> List[Dict]:
    """Recompute gold spans from cumulative surface lengths; normalize readings to katakana."""
    details: List[Dict] = []
    cursor = 0
    for g in gold_morphemes:
        surf = g.get("surface", "")
        start, end = cursor, cursor + len(surf)
        cursor = end
        details.append(
            {
                "start_pos": start,
                "end_pos": end,
                "surface": surf,
                "pos": g.get("pos", "*"),
                "pos_detail1": g.get("pos_detail1", "*"),
                "pos_detail2": g.get("pos_detail2", "*"),
                "base_form": g.get("base_form", ""),
                "reading": hiragana_to_katakana(g.get("reading", "")),
            }
        )
    return details


def match_morphemes_with_gold(candidates: List[Dict], gold_morphemes: List[Dict], text: str) -> List[Dict]:
    """Match candidate morphemes to gold and assign annotations ('?', '+', '-').

    Policy:
    - Initialize every candidate as '?'
    - Mark '+' for candidates at a gold span whose surface/POS/pos_detail1/base
      agree; a katakana-normalized reading match is preferred (strict), but
      when no strict match exists the remaining field-matches are accepted
    - Mark '-' for non-'+' candidates at a matched span, and for any candidate
      whose span strictly overlaps a matched span

    Args:
        candidates: Analyzer candidates (dicts with surface/POS/span keys).
        gold_morphemes: Gold morphemes in sentence order; their spans are
            derived from cumulative surface lengths.
        text: Sentence text (unused; kept for interface compatibility).

    Returns:
        A new list of candidate dict copies with an "annotation" key added
        and "inflection_type"/"inflection_form" defaulted to '*'.
    """
    gold_details = _gold_details_with_spans(gold_morphemes)

    # Copy candidates; default annotation and inflection fields.
    annotated: List[Dict] = []
    for cand in candidates:
        entry = {**cand}
        entry["annotation"] = "?"
        entry.setdefault("inflection_type", "*")
        entry.setdefault("inflection_form", "*")
        annotated.append(entry)

    # Index candidates by (start, end) span for O(1) gold-span lookup.
    span_to_cands: dict[tuple[int, int], list[Dict]] = {}
    for entry in annotated:
        span_to_cands.setdefault(_span_of(entry), []).append(entry)

    # Every '+' is assigned only at its gold span, so collecting matched spans
    # here replaces the previous dead `matched_spans` list AND the redundant
    # second pass that rebuilt the same spans from '+' annotations.
    matched_spans: set[tuple[int, int]] = set()
    for g in gold_details:
        span = (g["start_pos"], g["end_pos"])
        cands = span_to_cands.get(span, [])
        if not cands:
            continue
        strict: List[Dict] = []
        fallback: List[Dict] = []
        for entry in cands:
            if entry.get("surface", "") != g["surface"]:
                continue
            if entry.get("pos", "*") != g["pos"]:
                continue
            if entry.get("pos_detail1", "*") != g.get("pos_detail1", "*"):
                continue
            if entry.get("base_form", "") != g["base_form"]:
                continue
            if hiragana_to_katakana(entry.get("reading", "")) == g["reading"]:
                strict.append(entry)
            else:
                fallback.append(entry)
        chosen = strict or fallback
        if chosen:
            for entry in chosen:
                entry["annotation"] = "+"
            matched_spans.add(span)
            # Same-span candidates that were not chosen are wrong analyses.
            for entry in cands:
                if entry not in chosen and entry.get("annotation") != "+":
                    entry["annotation"] = "-"

    # Demote any remaining '?' candidate that strictly overlaps a matched span.
    for entry in annotated:
        if entry.get("annotation") == "+":
            continue
        cs, ce = _span_of(entry)
        if any(max(cs, ms) < min(ce, me) for ms, me in matched_spans):
            entry["annotation"] = "-"
    return annotated
|
|
|
|
def _deep_merge(base: Dict, override: Dict) -> Dict:
    """Recursively merge *override* into *base* (mutates *base* and returns it)."""
    for key, value in override.items():
        if key in base and isinstance(base[key], dict) and isinstance(value, dict):
            _deep_merge(base[key], value)
        else:
            base[key] = value
    return base


def _load_config(config_path: str) -> Dict:
    """Load a YAML config file, resolving a one-level 'extends' parent.

    Fix: yaml.safe_load returns None for an empty document, which previously
    made the `"extends" in config` check raise TypeError; guard with `or {}`.
    """
    config: Dict = {}
    if config_path and Path(config_path).exists():
        with open(config_path, "r") as f:
            config = yaml.safe_load(f) or {}
    if "extends" in config:
        parent_path = Path(config_path).parent / config["extends"]
        if parent_path.exists():
            with open(parent_path, "r") as f:
                parent_config = yaml.safe_load(f) or {}
            config = _deep_merge(parent_config, config)
    return config


def _collect_knp_files(input_dir: str, test_only: bool) -> List[Path]:
    """Collect KNP file paths; with *test_only*, restrict to the test-split IDs.

    Fix: when the test-id file is missing, return [] instead of risking an
    unbound `test_ids` reference.
    """
    if not test_only:
        return sorted(Path(input_dir).glob("**/*.knp"))
    knp_files: List[Path] = []
    test_id_file = Path("KWDLC/id/split_for_pas/test.id")
    if test_id_file.exists():
        with open(test_id_file, "r") as f:
            test_ids = [line.strip() for line in f if line.strip()]
        knp_base_dir = Path(input_dir)
        for file_id in test_ids:
            # KWDLC layout: files live in a directory named by the 13-char ID prefix.
            knp_path = knp_base_dir / file_id[:13] / f"{file_id}.knp"
            if knp_path.exists():
                knp_files.append(knp_path)
    return knp_files


def main():
    """Build training graphs from KWDLC with MeCab (JUMANDIC) and save them as .pt files."""
    parser = argparse.ArgumentParser(description="Create training data from KWDLC (JUMANDIC)")
    parser.add_argument("--input-dir", type=str, default="KWDLC/knp", help="Directory containing KNP files")
    parser.add_argument("--config", type=str, default="configs/gat.yaml", help="Path to config file")
    parser.add_argument("--limit", type=int, help="Max number of files to process")
    parser.add_argument("--test-only", action="store_true", help="Process only test split IDs")
    parser.add_argument("--jumandic-path", type=str, default="/var/lib/mecab/dic/juman-utf8", help="Path to JUMANDIC")
    args = parser.parse_args()

    config = _load_config(args.config)

    features_config = config.get("features", {})
    feature_dim = features_config.get("lexical_feature_dim", 100000)
    training_config = config.get("training", {})

    if training_config.get("annotations_dir"):
        output_dir = Path(training_config.get("annotations_dir"))
    else:
        output_dir = Path("annotations_kwdlc_juman")
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Lexical features: using {feature_dim} dims")
    print(f"Output directory: {output_dir}")

    analyzer = MeCabAnalyzer(
        jumandic_path=args.jumandic_path,
    )

    knp_files = _collect_knp_files(args.input_dir, args.test_only)
    if args.limit:
        knp_files = knp_files[: args.limit]

    print(f"Files to process: {len(knp_files)}")
    print(f"JUMANDIC: {args.jumandic_path}")
    print(f"Output to: {output_dir}")

    total_stats = defaultdict(int)
    annotation_idx = 0

    dm = DataModule(
        annotations_dir=str(output_dir),
        lexical_feature_dim=int(feature_dim),
        use_bidirectional_edges=bool(config.get("edge_features", {}).get("use_bidirectional_edges", True)),
    )

    for knp_path in tqdm(knp_files, desc="processing"):
        try:
            sentences = _load_gold_with_kyoto(knp_path)
            if not sentences:
                continue

            doc_id = knp_path.stem
            for s in sentences:
                s["source_id"] = doc_id

            for sentence in sentences:
                text = sentence["text"]
                gold_morphemes = sentence["morphemes"]
                source_id = sentence.get("source_id", doc_id)

                candidates = analyzer.get_morpheme_candidates(text)
                candidates = normalize_mecab_candidates(candidates)
                candidates = dedup_morphemes(candidates)
                if not candidates:
                    continue

                annotated_morphemes = match_morphemes_with_gold(candidates, gold_morphemes, text)
                edges = build_adjacent_edges(annotated_morphemes)

                # Lexical features are recomputed by the DataModule; drop stale ones.
                for m in annotated_morphemes:
                    m.pop("lexical_features", None)

                morphemes_with_feats = dm.compute_lexical_features(annotated_morphemes, text)
                graph = dm.create_graph_from_morphemes_data(
                    morphemes=morphemes_with_feats,
                    edges=edges,
                    text=text,
                    for_training=True,
                )
                if graph is None:
                    continue

                graph_file = output_dir / f"graph_{annotation_idx:04d}.pt"
                payload = {
                    "graph": graph,
                    "source_id": source_id,
                    "text": text,
                }
                torch.save(payload, graph_file)

                total_stats["sentences"] += 1
                total_stats["morphemes"] += len(annotated_morphemes)
                total_stats["positive"] += sum(1 for m in annotated_morphemes if m.get("annotation") == "+")
                total_stats["negative"] += sum(1 for m in annotated_morphemes if m.get("annotation") == "-")
                annotation_idx += 1

            total_stats["files"] += 1
        except Exception as e:
            # Best-effort batch job: report the failure and continue with the next file.
            print(f"Error ({knp_path}): {e}")
            total_stats["errors"] += 1

    print("\n" + "=" * 50)
    print("Processing complete")
    print("=" * 50)
    print(f"Files: {total_stats['files']}")
    print(f"Sentences: {total_stats['sentences']}")
    print(f"Morphemes: {total_stats['morphemes']}")
    print(f"Positive (+): {total_stats['positive']}")
    print(f"Negative (-): {total_stats['negative']}")
    if total_stats["errors"] > 0:
        print(f"Errors: {total_stats['errors']}")


if __name__ == "__main__":
    main()
|
|