File size: 13,259 Bytes
5ededda
1
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.12.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[
{"cell_type":"markdown","metadata":{},"source":"# Shield-82M: PII token classifier\n\nFine-tunes `distilroberta-base` on a 20k-example slice of `ai4privacy/pii-masking-200k` to tag personally identifiable information (PII) token by token, then uses the trained model to replace detected spans with coarse placeholders such as `[PERSON]` or `[EMAIL]`.\n\n**Pipeline:** install → load data and align span labels → metrics and model → train → inference demo → export → standalone inference class."},
{"cell_type":"code","source":"%pip install -q datasets transformers seqeval evaluate","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true},"outputs":[],"execution_count":null},
{"cell_type":"code","source":"import os\nimport shutil\nimport warnings\n\nimport numpy as np\nimport torch\nimport evaluate\nfrom datasets import load_dataset\nfrom transformers import (AutoTokenizer,\n                          AutoModelForTokenClassification,\n                          DataCollatorForTokenClassification,\n                          TrainingArguments,\n                          Trainer,\n                          set_seed)\n\nMODEL_NAME = \"distilroberta-base\"\nDATASET_NAME = \"ai4privacy/pii-masking-200k\"\nMAX_LENGTH = 512      # tokens per input; longer texts are truncated\nN_EXAMPLES = 20_000   # rows taken from the start of the dataset's train split\nTEST_SIZE = 0.1       # fraction of those rows held out for evaluation\nSEED = 42             # used by set_seed, the data split and TrainingArguments\n\nset_seed(SEED)","metadata":{"trusted":true},"outputs":[],"execution_count":null},
{"cell_type":"markdown","metadata":{},"source":"## Load data and align span labels to tokens\n\nThe dataset marks PII as character spans in `privacy_mask`. Every token that overlaps a span gets that span's label, all other tokens get `O`, and special tokens get `-100` so they are ignored. The labels are plain IO tags (no `B-`/`I-` prefixes)."},
{"cell_type":"code","source":"print(\"Loading dataset...\")\nraw_datasets = load_dataset(DATASET_NAME, split=f\"train[:{N_EXAMPLES}]\").train_test_split(\n    test_size=TEST_SIZE, seed=SEED\n)\n\nprint(\"Extracting label list...\")\n# Collect labels from every split so a label that only appears in the test split\n# cannot raise a KeyError in `label2id` during tokenization.\nunique_labels = {\n    item[\"label\"]\n    for split in raw_datasets.values()\n    for mask_list in split[\"privacy_mask\"]\n    for item in mask_list\n}\n\nlabel_list = [\"O\"] + sorted(unique_labels)\nlabel2id = {label: i for i, label in enumerate(label_list)}\nid2label = {i: label for i, label in enumerate(label_list)}\n\nprint(f\"Found {len(label_list)} classes incl. 'O': {label_list[:5]}...\")\n\nprint(\"Loading tokenizer...\")\ntokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n\n\ndef align_labels_with_spans(examples):\n    \"\"\"Tokenize a batch and label every token with the PII span it overlaps.\n\n    Special tokens (offset (0, 0)) get -100 so the loss and metrics skip them.\n    A token counts as PII when it overlaps a span at all, not only when it lies\n    entirely inside it, so partially covered tokens are never trained as the O class.\n    \"\"\"\n    tokenized_inputs = tokenizer(\n        examples[\"source_text\"],\n        truncation=True,\n        max_length=MAX_LENGTH,\n        return_offsets_mapping=True,\n        padding=False,\n    )\n    # Offsets are only needed here; pop them so they don't end up as a dataset column.\n    offset_mappings = tokenized_inputs.pop(\"offset_mapping\")\n\n    all_labels = []\n    for offsets, spans in zip(offset_mappings, examples[\"privacy_mask\"]):\n        token_labels = []\n        for tok_start, tok_end in offsets:\n            if tok_start == 0 and tok_end == 0:\n                token_labels.append(-100)\n                continue\n\n            label_id = label2id[\"O\"]\n            for span in spans:\n                # Overlap test, treating both intervals as [start, end).\n                if tok_start < span[\"end\"] and tok_end > span[\"start\"]:\n                    label_id = label2id[span[\"label\"]]\n                    break\n            token_labels.append(label_id)\n\n        all_labels.append(token_labels)\n\n    tokenized_inputs[\"labels\"] = all_labels\n    return tokenized_inputs\n\n\nprint(\"Tokenizing...\")\ntokenized_datasets = raw_datasets.map(\n    align_labels_with_spans,\n    batched=True,\n    remove_columns=raw_datasets[\"train\"].column_names,\n)\n\nprint(\"Data preparation for Shield 82M done.\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},
{"cell_type":"markdown","metadata":{},"source":"## Metrics and model\n\n`seqeval` reads the entity type from a `B-`/`I-` prefix, so the plain IO labels are prefixed with `I-` before scoring. Without the prefix a label such as `EMAIL` would be parsed as tag `E` with type `MAIL`, which makes the entity-level scores meaningless."},
{"cell_type":"code","source":"metric = evaluate.load(\"seqeval\")\n\n\ndef to_seqeval_tag(label):\n    \"\"\"Return the I-prefixed form of a plain IO label; the O label is returned unchanged.\"\"\"\n    return label if label == \"O\" else f\"I-{label}\"\n\n\ndef compute_metrics(eval_pred):\n    \"\"\"Compute entity-level precision/recall/F1 and token accuracy for the Trainer.\"\"\"\n    logits, labels = eval_pred\n    predictions = np.argmax(logits, axis=2)\n\n    true_predictions, true_labels = [], []\n    for pred_row, label_row in zip(predictions, labels):\n        # -100 marks special and padding positions that have no gold label.\n        kept = [(pred, gold) for pred, gold in zip(pred_row, label_row) if gold != -100]\n        true_predictions.append([to_seqeval_tag(label_list[pred]) for pred, _ in kept])\n        true_labels.append([to_seqeval_tag(label_list[gold]) for _, gold in kept])\n\n    results = metric.compute(predictions=true_predictions, references=true_labels)\n    return {\n        \"precision\": results[\"overall_precision\"],\n        \"recall\": results[\"overall_recall\"],\n        \"f1\": results[\"overall_f1\"],\n        \"accuracy\": results[\"overall_accuracy\"],\n    }\n\n\nmodel = AutoModelForTokenClassification.from_pretrained(\n    MODEL_NAME, num_labels=len(label_list), id2label=id2label, label2id=label2id\n)","metadata":{"trusted":true},"outputs":[],"execution_count":null},
{"cell_type":"code","source":"def count_parameters(model):\n    \"\"\"Return the number of trainable parameters in `model`.\"\"\"\n    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n\n\nparams = count_parameters(model)\nprint(f\"The model has {params:,} trainable params.\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},
{"cell_type":"markdown","metadata":{},"source":"## Training\n\nOne evaluation and one checkpoint per epoch; the checkpoint with the best entity-level F1 is restored at the end."},
{"cell_type":"code","source":"training_args = TrainingArguments(\n    output_dir=\"./Shield\",\n    eval_strategy=\"epoch\",\n    save_strategy=\"epoch\",\n    learning_rate=2e-5,\n    per_device_train_batch_size=16,\n    per_device_eval_batch_size=16,\n    num_train_epochs=3,\n    weight_decay=0.01,\n    report_to=\"none\",\n    load_best_model_at_end=True,\n    metric_for_best_model=\"f1\",  # pick the checkpoint by entity F1, not eval loss\n    seed=SEED,\n)\n\ndata_collator = DataCollatorForTokenClassification(tokenizer)\n\ntrainer = Trainer(\n    model=model,\n    args=training_args,\n    train_dataset=tokenized_datasets[\"train\"],\n    eval_dataset=tokenized_datasets[\"test\"],\n    data_collator=data_collator,\n    compute_metrics=compute_metrics,\n)\n\ntrainer.train()","metadata":{"trusted":true},"outputs":[],"execution_count":null},
{"cell_type":"markdown","metadata":{},"source":"## Inference demo\n\nFine-grained labels are collapsed into coarse groups (e.g. `FIRSTNAME` + `LASTNAME` → `PERSON`), and consecutive tokens of the same group become a single placeholder. Text beyond `MAX_LENGTH` tokens is not scanned; a warning is issued when that happens."},
{"cell_type":"code","source":"GROUPS = {\n    \"FIRSTNAME\": \"PERSON\", \"MIDDLENAME\": \"PERSON\", \"LASTNAME\": \"PERSON\",\n    \"BUILDINGNUMBER\": \"ADDRESS\", \"STREET\": \"ADDRESS\", \"CITY\": \"ADDRESS\",\n    \"STATE\": \"ADDRESS\", \"ZIPCODE\": \"ADDRESS\", \"SECONDARYADDRESS\": \"ADDRESS\",\n    \"EMAIL\": \"EMAIL\", \"PHONENUMBER\": \"PHONE\", \"PHONEIMEI\": \"PHONE\",\n    # A time of day is not a date of birth; keep it separate, as ShieldFilter does.\n    \"DATE\": \"DOB\", \"TIME\": \"TIME\",\n}\n\n\ndef shield_filter_production(text):\n    \"\"\"Return `text` with every PII span detected by `model` replaced by `[GROUP]`.\n\n    Fine-grained labels are mapped through GROUPS (unmapped labels are used as-is)\n    and consecutive tokens of the same group are merged into one placeholder.\n    Only the first MAX_LENGTH tokens are scanned; a warning is issued if any\n    non-whitespace text after that point is left unfiltered.\n    \"\"\"\n    model.eval()  # make sure dropout is off for inference\n    inputs = tokenizer(\n        text,\n        return_tensors=\"pt\",\n        truncation=True,\n        max_length=MAX_LENGTH,\n        return_offsets_mapping=True,\n    ).to(model.device)\n    offsets = inputs.pop(\"offset_mapping\")[0].cpu().numpy()\n\n    with torch.no_grad():\n        logits = model(**inputs).logits\n    predictions = torch.argmax(logits, dim=2)[0].cpu().numpy()\n\n    # Truncation is silent; make any unscanned (and therefore unmasked) tail visible.\n    scanned_until = int(offsets[:, 1].max())\n    if text[scanned_until:].strip():\n        warnings.warn(\n            f\"Input was truncated to {MAX_LENGTH} tokens; text after character \"\n            f\"{scanned_until} was NOT filtered.\"\n        )\n\n    spans_to_replace = []\n    current_group = None\n    start_char = last_char = -1\n\n    for pred_id, (tok_start, tok_end) in zip(predictions, offsets):\n        if tok_start == 0 and tok_end == 0:\n            continue  # special token\n\n        label = id2label[int(pred_id)]\n        group_tag = None if label == \"O\" else GROUPS.get(label, label)\n\n        if group_tag != current_group:\n            # Close the running span (if any); a None group means we are outside PII.\n            if current_group is not None:\n                spans_to_replace.append((start_char, last_char, current_group))\n            current_group = group_tag\n            start_char = int(tok_start)\n        if current_group is not None:\n            last_char = int(tok_end)\n\n    if current_group is not None:\n        spans_to_replace.append((start_char, last_char, current_group))\n\n    # Replace from right to left so earlier character offsets stay valid.\n    filtered_text = text\n    for start, end, tag in sorted(spans_to_replace, key=lambda span: span[0], reverse=True):\n        filtered_text = filtered_text[:start] + f\"[{tag}]\" + filtered_text[end:]\n\n    return filtered_text","metadata":{"trusted":true},"outputs":[],"execution_count":null},
{"cell_type":"code","source":"tests = [\n    \"Mein Name ist Max Mustermann und ich wohne in der Hauptstraße 5, Berlin. Meine Email ist max@example.com.\",\n    \"Liebe Lena, ich möchte dir heute mitteilen, dass ich ins Altmühltal umgezogen bin.\",\n    \"Alice was born on 1990-01-02 and lives at 1 Main St.\",\n    \"Mon e-mail est jean.dupont@example.fr et mon téléphone est +33 6 12 34 56 78.\",\n]\n\nfor t in tests:\n    print(f\"In:  {t}\")\n    print(f\"Out: {shield_filter_production(t)}\\n\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},
{"cell_type":"markdown","metadata":{},"source":"## Export\n\nSave the restored best checkpoint together with the tokenizer, then zip the folder for download."},
{"cell_type":"code","source":"save_directory = \"./Shield-v1-final\"\nos.makedirs(save_directory, exist_ok=True)\n\ntrainer.save_model(save_directory)\ntokenizer.save_pretrained(save_directory)\n\nprint(f\"Model files saved to {save_directory}.\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},
{"cell_type":"code","source":"shutil.make_archive(\"Shield_v1_Model\", \"zip\", save_directory)\n\nprint(\"Success! Zipped successfully :D\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},
{"cell_type":"markdown","metadata":{},"source":"## Standalone inference class\n\nSelf-contained (own imports) so it can be copied into a script. By default it loads `LH-Tech-AI/Shield-82M` via `from_pretrained`; pass `model_path=\"./Shield-v1-final\"` to use the model trained above."},
{"cell_type":"code","source":"import warnings\n\nimport torch\nfrom transformers import AutoTokenizer, AutoModelForTokenClassification\n\n\nclass ShieldFilter:\n    \"\"\"Mask PII in free text with coarse `[GROUP]` placeholders using a Shield token classifier.\"\"\"\n\n    def __init__(self, model_path=\"LH-Tech-AI/Shield-82M\", max_length=512):\n        print(f\"Loading Shield-82M from {model_path}...\")\n        self.tokenizer = AutoTokenizer.from_pretrained(model_path)\n        self.model = AutoModelForTokenClassification.from_pretrained(model_path)\n        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        self.model.to(self.device)\n        self.model.eval()\n        # Longest input, in tokens, scanned per protect() call; anything after it is not masked.\n        self.max_length = max_length\n\n        # Fine-grained model label -> placeholder group; labels not listed here are used verbatim.\n        self.group_map = {\n            # Personal\n            \"FIRSTNAME\": \"PERSON\", \"MIDDLENAME\": \"PERSON\", \"LASTNAME\": \"PERSON\",\n            \"USERNAME\": \"PERSON\", \"PREFIX\": \"PERSON\", \"AGE\": \"AGE\", \"GENDER\": \"GENDER\", \"SEX\": \"GENDER\",\n\n            # Address and location\n            \"BUILDINGNUMBER\": \"ADDRESS\", \"STREET\": \"ADDRESS\", \"CITY\": \"ADDRESS\",\n            \"STATE\": \"ADDRESS\", \"ZIPCODE\": \"ADDRESS\", \"SECONDARYADDRESS\": \"ADDRESS\",\n            \"COUNTY\": \"ADDRESS\", \"NEARBYGPSCOORDINATE\": \"LOCATION\", \"ORDINALDIRECTION\": \"LOCATION\",\n\n            # Contact\n            \"EMAIL\": \"EMAIL\", \"PHONENUMBER\": \"PHONE\", \"PHONEIMEI\": \"PHONE\", \"URL\": \"URL\",\n\n            # Finances\n            \"IBAN\": \"BANK_ACCOUNT\", \"BIC\": \"BANK_ACCOUNT\", \"ACCOUNTNUMBER\": \"BANK_ACCOUNT\",\n            \"CREDITCARDNUMBER\": \"CREDIT_CARD\", \"CREDITCARDCVV\": \"CREDIT_CARD\", \"CREDITCARDISSUER\": \"CREDIT_CARD\",\n            \"BITCOINADDRESS\": \"CRYPTO\", \"ETHEREUMADDRESS\": \"CRYPTO\", \"LITECOINADDRESS\": \"CRYPTO\",\n            \"AMOUNT\": \"AMOUNT\", \"CURRENCY\": \"AMOUNT\", \"CURRENCYCODE\": \"AMOUNT\",\n            \"CURRENCYNAME\": \"AMOUNT\", \"CURRENCYSYMBOL\": \"AMOUNT\",\n\n            # IT & Security\n            \"IP\": \"IT_INFO\", \"IPV4\": \"IT_INFO\", \"IPV6\": \"IT_INFO\", \"MAC\": \"IT_INFO\",\n            \"PASSWORD\": \"PASSWORD\", \"PIN\": \"PASSWORD\", \"USERAGENT\": \"IT_INFO\",\n\n            # Work\n            \"COMPANYNAME\": \"ORGANIZATION\", \"JOBTITLE\": \"JOB\", \"JOBAREA\": \"JOB\", \"JOBTYPE\": \"JOB\",\n\n            # Documents and vehicles\n            \"SSN\": \"ID_DOC\", \"VEHICLEVIN\": \"VEHICLE\", \"VEHICLEVRM\": \"VEHICLE\",\n\n            # Time\n            \"DATE\": \"DOB\", \"DOB\": \"DOB\", \"TIME\": \"TIME\",\n        }\n\n    def protect(self, text):\n        \"\"\"Return `text` with detected PII spans replaced by `[GROUP]` placeholders.\n\n        Consecutive tokens of the same group are merged into one placeholder.\n        Only the first `self.max_length` tokens are scanned; a warning is issued\n        if any non-whitespace text after that point is left unfiltered.\n        \"\"\"\n        inputs = self.tokenizer(\n            text,\n            return_tensors=\"pt\",\n            truncation=True,\n            max_length=self.max_length,\n            return_offsets_mapping=True,\n        ).to(self.device)\n        offsets = inputs.pop(\"offset_mapping\")[0].cpu().numpy()\n\n        with torch.no_grad():\n            logits = self.model(**inputs).logits\n        predictions = torch.argmax(logits, dim=2)[0].cpu().numpy()\n        id2label = self.model.config.id2label\n\n        # Truncation is silent; make any unscanned (and therefore unmasked) tail visible.\n        scanned_until = int(offsets[:, 1].max())\n        if text[scanned_until:].strip():\n            warnings.warn(\n                f\"Input was truncated to {self.max_length} tokens; text after character \"\n                f\"{scanned_until} was NOT filtered.\"\n            )\n\n        spans_to_replace = []\n        current_group = None\n        start_char = last_char = -1\n\n        for pred_id, (tok_start, tok_end) in zip(predictions, offsets):\n            if tok_start == 0 and tok_end == 0:\n                continue  # special token\n\n            label = id2label[int(pred_id)]\n            group_tag = None if label == \"O\" else self.group_map.get(label, label)\n\n            if group_tag != current_group:\n                # Close the running span (if any); a None group means we are outside PII.\n                if current_group is not None:\n                    spans_to_replace.append((start_char, last_char, current_group))\n                current_group = group_tag\n                start_char = int(tok_start)\n            if current_group is not None:\n                last_char = int(tok_end)\n\n        if current_group is not None:\n            spans_to_replace.append((start_char, last_char, current_group))\n\n        # Replace from right to left so earlier character offsets stay valid.\n        filtered_text = text\n        for start, end, tag in sorted(spans_to_replace, key=lambda span: span[0], reverse=True):\n            filtered_text = filtered_text[:start] + f\"[{tag}]\" + filtered_text[end:]\n\n        return filtered_text\n\n\nif __name__ == \"__main__\":\n    shield = ShieldFilter()\n    sample = \"My name is John Doe. Email: john@example.com. Phone: +49 123 45678.\"\n    print(f\"Original: {sample}\")\n    print(f\"Protected: {shield.protect(sample)}\")","metadata":{"trusted":true},"outputs":[],"execution_count":null}
]}