{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "source": [ "!pip install dytr" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "KF_5u6BmOzP_", "outputId": "8253e48b-1699-4379-e930-614fb494fb67" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting dytr\n", " Downloading dytr-0.1.0-py3-none-any.whl.metadata (14 kB)\n", "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.12/dist-packages (from dytr) (2.10.0+cu128)\n", "Requirement already satisfied: numpy>=1.19.0 in /usr/local/lib/python3.12/dist-packages (from dytr) (2.0.2)\n", "Requirement already satisfied: pandas>=1.3.0 in /usr/local/lib/python3.12/dist-packages (from dytr) (2.2.2)\n", "Requirement already satisfied: scikit-learn>=0.24.0 in /usr/local/lib/python3.12/dist-packages (from dytr) (1.6.1)\n", "Requirement already satisfied: tqdm>=4.62.0 in /usr/local/lib/python3.12/dist-packages (from dytr) (4.67.3)\n", "Requirement already satisfied: requests>=2.25.0 in /usr/local/lib/python3.12/dist-packages (from dytr) (2.32.4)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas>=1.3.0->dytr) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas>=1.3.0->dytr) (2025.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas>=1.3.0->dytr) (2025.3)\n", "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.25.0->dytr) (3.4.6)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.25.0->dytr) (3.11)\n", "Requirement 
already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.25.0->dytr) (2.5.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.25.0->dytr) (2026.2.25)\n", "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=0.24.0->dytr) (1.16.3)\n", "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=0.24.0->dytr) (1.5.3)\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from scikit-learn>=0.24.0->dytr) (3.6.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (3.25.2)\n", "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (4.15.0)\n", "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (75.2.0)\n", "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (1.14.0)\n", "Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (3.6.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (3.1.6)\n", "Requirement already satisfied: fsspec>=0.8.5 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (2025.3.0)\n", "Requirement already satisfied: cuda-bindings==12.9.4 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (12.9.4)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (12.8.93)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) 
(12.8.90)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (12.8.90)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (9.10.2.21)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.8.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (12.8.4.1)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.3.3.83 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (11.3.3.83)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.9.90 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (10.3.9.90)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.7.3.90 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (11.7.3.90)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.5.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (12.5.8.93)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (0.7.1)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (2.27.5)\n", "Requirement already satisfied: nvidia-nvshmem-cu12==3.4.5 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (3.4.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.8.90 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (12.8.90)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.8.93 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (12.8.93)\n", "Requirement already satisfied: nvidia-cufile-cu12==1.13.1.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.10.0->dytr) (1.13.1.3)\n", "Requirement already satisfied: triton==3.6.0 in /usr/local/lib/python3.12/dist-packages 
(from torch>=1.10.0->dytr) (3.6.0)\n", "Requirement already satisfied: cuda-pathfinder~=1.1 in /usr/local/lib/python3.12/dist-packages (from cuda-bindings==12.9.4->torch>=1.10.0->dytr) (1.4.3)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas>=1.3.0->dytr) (1.17.0)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=1.10.0->dytr) (1.3.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=1.10.0->dytr) (3.0.3)\n", "Downloading dytr-0.1.0-py3-none-any.whl (73 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m73.4/73.4 kB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: dytr\n", "Successfully installed dytr-0.1.0\n" ] } ] }, { "cell_type": "code", "source": [ "from dytr import DynamicTransformer, PretrainedModelLoader, ModelConfig, TaskConfig, TrainingStrategy, Trainer, SingleDatasetProcessing\n", "import pandas as pd\n", "\n", "#model_name='prajjwal1/bert-tiny'\n", "model_name='asafaya/bert-mini-arabic'\n", "\n", "\n", "# 1. 
Configure your transformer\n", "config = ModelConfig(\n", " embed_dim=256,#this will be changed automatically for finetune model\n", " num_layers=6,#\n", " num_heads=8,#\n", " max_seq_len=256,\n", " tokenizer_name = model_name,\n", " use_simple_tokenizer =True,\n", " special_tokens ={},\n", " per_device_train_batch_size=64,\n", " num_train_epochs=5,\n", " per_device_eval_batch_size=16,\n", "\n", ")" ], "metadata": { "id": "fd2NAbWDOxyt" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "#model loading" ], "metadata": { "id": "UPAxCwJhWDzu" } }, { "cell_type": "code", "source": [ "# pretrained model\n", "\n", "loader = PretrainedModelLoader()\n", "\n", "## to confirm the model name is supported or not before loading,\n", "info = loader.get_model_info(model_name)\n", "print(f\"\\nModel info: {info}\")\n", "\n", "## model loading\n", "model=loader.load_pretrained(model_name,config)\n", "\n", "# to build model from scratch use the following line instead\n", "#model=DynamicTransformer(config)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HAC_8VPCWIpk", "outputId": "44e57871-aceb-41d4-e999-fb56433fb3b9" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\n", "Model info: {'model_name': 'asafaya/bert-mini-arabic', 'model_type': 'bert', 'supported': True, 'architecture': {'hidden_size': 256, 'num_layers': 4, 'num_heads': 4, 'vocab_size': 32000, 'max_position_embeddings': 512}}\n", "\n", "============================================================\n", "Loading BERT model: asafaya/bert-mini-arabic\n", "============================================================\n", "Downloading config.json...\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "config.json: 100%|██████████| 509/509 [00:00<00:00, 2.93MB/s]" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Downloading pytorch_model.bin...\n" ] }, { "output_type": "stream", "name": "stderr", 
"text": [ "\n", "pytorch_model.bin: 100%|██████████| 46.6M/46.6M [00:00<00:00, 233MB/s]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Downloading vocab.txt...\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "vocab.txt: 334kB [00:00, 29.1MB/s]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "BERT config:\n", " Hidden size: 256\n", " Layers: 4\n", " Attention heads: 4\n", " Max position embeddings: 512\n", " Vocabulary size: 32000\n", "\n", "Initializing DynamicTransformer...\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Downloading vocab.txt: | | 326k/0.00 [00:00<00:00, 29.7MB/s]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "************************************************************\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Downloading special_tokens_map.json: 100%|██████████| 112/112 [00:00<00:00, 448kB/s]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Tokenizer loaded with vocab size: 32000\n", "\n", "Loading model weights...\n", " Mapped word embeddings: shape torch.Size([32000, 256])\n", " Mapped layer 0\n", " Mapped layer 1\n", " Mapped layer 2\n", " Mapped layer 3\n", " Initialized final layer norm with identity\n", "\n", "✓ Successfully loaded BERT model as encoder\n", " Encoder parameters: 12,404,224\n", " Total model parameters: 12,404,224\n", " Embed dim: 256\n", " Layers: 4\n", " Heads: 4\n", " Vocabulary size: 32000\n", "\n", "📝 Note: This model has no Tasks: Add tasks using model.add_task() or Train the model on different Tasks to be added\n" ] } ] }, { "cell_type": "code", "source": [ "\n", "# congrats the pretrained model now became dytr architecture\n", "model" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Kvx0showXqpm", "outputId": "972c1425-41a8-4c1a-dd70-bb7ae34a62aa" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "DynamicTransformer(\n", " 
(shared_embedding): Embedding(32000, 256, padding_idx=0)\n", " (encoder): TransformerEncoder(\n", " (embedding): Embedding(32000, 256, padding_idx=0)\n", " (layers): ModuleList(\n", " (0-3): 4 x EncoderLayer(\n", " (attention): MultiHeadAttention(\n", " (q_proj): Linear(in_features=256, out_features=256, bias=True)\n", " (k_proj): Linear(in_features=256, out_features=256, bias=True)\n", " (v_proj): Linear(in_features=256, out_features=256, bias=True)\n", " (out_proj): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (attention_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (ffn_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (ffn): FeedForward(\n", " (gate_proj): Linear(in_features=256, out_features=1024, bias=True)\n", " (up_proj): Linear(in_features=256, out_features=1024, bias=True)\n", " (down_proj): Linear(in_features=1024, out_features=256, bias=True)\n", " (activation): GELU(approximate='none')\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (decoders): ModuleDict()\n", " (task_heads): ModuleDict()\n", ")" ] }, "metadata": {}, "execution_count": 4 } ] }, { "cell_type": "code", "source": [ "tokenizer =model.tokenizer\n", "print(tokenizer.encode(\" السلام عليكم\"))" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gLixQBprXy5D", "outputId": "f8111538-3fcb-495d-8b6d-958df66a1ec5" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "{'input_ids': [2675, 4240], 'attention_mask': [1, 1]}\n" ] } ] }, { "cell_type": "markdown", "source": [ "#load your dataset" ], "metadata": { "id": "DKoDQCrRS7ZM" } }, { "cell_type": "code", "source": [ "from datasets import load_dataset\n", "dataset_tr = load_dataset('alsubari/arabic-grammar-errors', split='validation')\n", 
"dataset_tr=pd.DataFrame(dataset_tr)\n", "dataset_tr.head()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 920 }, "id": "0cEqlbonS2qB", "outputId": "0986b0c6-2cff-4d7e-c372-299324c32937" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " text \\\n", "0 مما يؤكل منه قال لهما ( ما أصبتما من أخيكما أع... \n", "1 وتسعى السلطنة لتشجيع وتطوير هذا القطاع ليتيح ف... \n", "2 خلال اقتراح بعض الحلول المناسبة للمشكلات والصع... \n", "3 في الحديث الصحيح مع النبي عليه وعلى آله وصحبه ... \n", "4 كيلومترات تكون بداية الانطلاقة من أمام مستشفى ... \n", "\n", " correct_text label \\\n", "0 مما يؤكل منه قال لهما ( ما أصبتما من أخيكما أع... 3 \n", "1 وتسعى السلطنة لتشجيع وتطوير هذا القطاع ليتيح ف... 3 \n", "2 خلال اقتراح بعض الحلول المناسبة للمشكلات والصع... 6 \n", "3 في الحديث الصحيح عن النبي عليه وعلى آله وصحبه ... 6 \n", "4 كيلومترات تكون بداية الانطلاقة من أمام مستشفى ... 6 \n", "\n", " tags \n", "0 0 0 0 0 0 0 0 0 0 0 0 0 237 0 0 36 0 0 0 0 0 0... \n", "1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 95 0 0 0 0 0 0 0... \n", "2 0 0 0 0 0 0 0 0 0 0 0 0 115 0 0 0 0 0 0 0 0 0 ... \n", "3 0 0 0 105 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ... \n", "4 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ... " ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
textcorrect_textlabeltags
0مما يؤكل منه قال لهما ( ما أصبتما من أخيكما أع...مما يؤكل منه قال لهما ( ما أصبتما من أخيكما أع...30 0 0 0 0 0 0 0 0 0 0 0 237 0 0 36 0 0 0 0 0 0...
1وتسعى السلطنة لتشجيع وتطوير هذا القطاع ليتيح ف...وتسعى السلطنة لتشجيع وتطوير هذا القطاع ليتيح ف...30 0 0 0 0 0 0 0 0 0 0 0 0 0 0 95 0 0 0 0 0 0 0...
2خلال اقتراح بعض الحلول المناسبة للمشكلات والصع...خلال اقتراح بعض الحلول المناسبة للمشكلات والصع...60 0 0 0 0 0 0 0 0 0 0 0 115 0 0 0 0 0 0 0 0 0 ...
3في الحديث الصحيح مع النبي عليه وعلى آله وصحبه ...في الحديث الصحيح عن النبي عليه وعلى آله وصحبه ...60 0 0 105 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
4كيلومترات تكون بداية الانطلاقة من أمام مستشفى ...كيلومترات تكون بداية الانطلاقة من أمام مستشفى ...60 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "dataset_tr", "summary": "{\n \"name\": \"dataset_tr\",\n \"rows\": 16212,\n \"fields\": [\n {\n \"column\": \"text\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 16189,\n \"samples\": [\n \"( \\u0627\\u0644\\u0648\\u0637\\u0646 ) \\u0627\\u0644\\u062a\\u0642\\u062a \\u0628\\u0645\\u062c\\u0645\\u0648\\u0639\\u0629 \\u0645\\u0646 \\u0627\\u0644\\u0627\\u0637\\u0641\\u0627\\u0644 \\u0641\\u064a \\u0627\\u0644\\u0639\\u064a\\u0648\\u062f \\u0644\\u062a\\u0623\\u062e\\u0630 \\u0622\\u0631\\u0627\\u0621\\u0647\\u0645 \\u0627\\u0646\\u0637\\u0628\\u0627\\u0639\\u0627\\u0646 \\u0639\\u0646 \\u0627\\u0644\\u0639\\u064a\\u0648\\u062f \\u0641\\u0642\\u0627\\u0644\\u062a \\u0645\\u0631\\u064a\\u0645 \\u0628\\u0646\\u064a\\u062a\\u0645\\u0627 \\u0628\\u0634\\u064a\\u0631 \\u0627\\u0644\\u0631\\u064a\\u0627\\u0645\\u064a : \\u0627\\u0646 \\u0627\\u0644\\u0639\\u064a\\u0648\\u062f \\u0644\\u0647 \\u062a\\u0637\\u0639\\u0645\\u0627\\u0646 \\u062e\\u0627\\u0635 \\u0648\\u0646\\u062d\\u0646 \\u064a\\u062d\\u0631\\u0635 \\u0643\\u062b\\u064a\\u0631\\u0627 \\u0639\\u0644\\u0649 \\u062d\\u0636\\u0648\\u0631 .\",\n \"\\u064a\\u0627 \\u0647\\u0644\\u0627 \\u0648\\u062d\\u064a\\u0627\\u0643\\u0645 \\u0645\\u0646 \\u0627\\u062c\\u0645\\u0644 \\u0627\\u0644\\u0627\\u0645\\u0633\\u064a\\u0627\\u062a \\u0641\\u064a \\u0645\\u0647\\u0631\\u062c\\u0627\\u0646 \\u0645\\u0633\\u0642\\u0637 \\u0627\\u0644\\u0627\\u0645\\u0633\\u064a\\u0627\\u062a \\u0627\\u0644\\u0634\\u0639\\u0631\\u064a\\u0629 \\u0627\\u0644\\u062a\\u064a \\u0627\\u0633\\u062a\\u0642\\u0637\\u0628\\u062a \\u062c\\u0645\\u0627\\u0647\\u064a\\u0631 \\u063a\\u0641\\u064a\\u0631\\u0629 \\u062a\\u0641\\u0627\\u0639\\u0644\\u062a \\u0628\\u062d\\u0645\\u0627\\u0633 \\u0643\\u064a \\u0627\\u0644\\u0634\\u0639\\u0631\\u0627\\u0621 \\u0627\\u0644\\u0634\\u0639\\u0631\\u0627\\u0621 
\\u062d\\u0636\\u0631\\u0648\\u0627 \\u0628\\u0642\\u0648\\u0629 \\u0643\\u0644\\u0645\\u0627\\u062a\\u0647\\u0645 \\u0627\\u0644\\u0634\\u0639\\u0631\\u064a\\u0629 \\u0627\\u0644\\u062c\\u0645\\u064a\\u0644\\u0629 \\u0648\\u0645\\u0646 \\u062c\\u0627\\u0646\\u0628\\u0647 \\u0627\\u0644\\u062d\\u0636\\u0648\\u0631 \\u0627\\u0644\\u0645\\u062a\\u0645\\u064a\\u0632 \\u0627\\u0633\\u062a\\u0647\\u0648\\u0649 .\",\n \"2 \\u0627\\u0646\\u062a\\u0634\\u0627\\u0631 \\u0627\\u0645\\u062a\\u0644\\u0627\\u0643 \\u0627\\u0644\\u0627\\u0635\\u0648\\u0644 \\u0627\\u0644\\u0645\\u0627\\u0644\\u064a\\u0629 \\u0627\\u0644\\u063a\\u0631\\u0636 \\u0627\\u0644\\u0627\\u0633\\u0627\\u0633\\u064a \\u0645\\u0646 \\u0627\\u0644\\u062a\\u0648\\u0631\\u064a\\u0642 \\u0647\\u0648 \\u062a\\u0648\\u0632\\u064a\\u0639 \\u0627\\u0644\\u0627\\u0635\\u0648\\u0644 \\u0627\\u0644\\u0645\\u0627\\u0644\\u064a\\u0629 \\u062b\\u0645 \\u0627\\u0643\\u0628\\u0631 \\u0639\\u062f\\u062f \\u0645\\u0645\\u0643\\u0646 \\u0645\\u0646 \\u0627\\u0644\\u0645\\u062f\\u062e\\u0631\\u064a\\u0646 .\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"correct_text\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 15257,\n \"samples\": [\n \"18 \\u0627\\u0644\\u0649 25 \\u064a\\u0648\\u0644\\u064a\\u0648 \\u0627\\u0644\\u062d\\u0627\\u0644\\u064a \\u0648\\u0645\\u0634\\u0627\\u0631\\u0643\\u062a\\u0647\\u0627 \\u0627\\u064a\\u0636\\u0627 \\u0641\\u064a \\u0627\\u0644\\u0628\\u0637\\u0648\\u0644\\u062a\\u064a\\u0646 \\u0627\\u0644\\u062f\\u0648\\u0644\\u064a\\u062a\\u064a\\u0646 \\u0644\\u0644\\u0645\\u062c\\u0645\\u0648\\u0639\\u0629 \\u0627\\u0644\\u0631\\u0627\\u0628\\u0639\\u0629 \\u0628\\u062f\\u0645\\u0634\\u0642 \\u0645\\u0646 9 \\u0627\\u0644\\u0649 21 \\u0627\\u063a\\u0633\\u0637\\u0633 \\u0627\\u0644\\u0642\\u0627\\u062f\\u0645 \\u0648\\u0627\\u0644\\u0628\\u0637\\u0648\\u0644\\u0629 \\u0627\\u0644\\u062f\\u0648\\u0644\\u064a\\u0629 
\\u0644\\u0644\\u0645\\u062c\\u0645\\u0648\\u0639\\u0629 \\u0627\\u0644\\u062e\\u0627\\u0645\\u0633\\u0629 \\u0627\\u0644\\u0645\\u0642\\u0631\\u0631 \\u0627\\u0642\\u0627\\u0645\\u062a\\u0647\\u0627 \\u0641\\u064a \\u0644\\u0628\\u0646\\u0627\\u0646 \\u062e\\u0644\\u0627\\u0644 \\u0627\\u0644\\u0641\\u062a\\u0631\\u0629 \\u0645\\u0646 .\",\n \"\\u0648\\u0645\\u0639 \\u0628\\u062f\\u0627\\u064a\\u0627\\u062a \\u0627\\u0644\\u062b\\u0645\\u0627\\u0646\\u064a\\u0646\\u064a\\u0627\\u062a\\u060c \\u062a\\u0642\\u0644\\u0635 \\u0627\\u0644\\u0627\\u0647\\u062a\\u0645\\u0627\\u0645 \\u0627\\u0644\\u0631\\u0633\\u0645\\u064a \\u0641\\u064a \\u0627\\u0644\\u0635\\u064a\\u0646 \\u0628\\u0627\\u0644\\u0645\\u0648\\u0633\\u064a\\u0642\\u0649 \\u0627\\u0644\\u0634\\u0639\\u0628\\u064a\\u0629 \\u0628\\u0639\\u062f \\u0627\\u0644\\u062a\\u062d\\u0648\\u0644 \\u0625\\u0644\\u0649 \\u0633\\u064a\\u0627\\u0633\\u0629 \\u0627\\u0644\\u0627\\u0642\\u062a\\u0635\\u0627\\u062f \\u0627\\u0644\\u062d\\u0631 .\",\n \"\\u0641\\u064a \\u0627\\u0644\\u0645\\u0631\\u0643\\u0632 \\u0627\\u0644\\u062b\\u0627\\u0646\\u064a \\u0628\\u0639\\u062f \\u0633\\u062f\\u0627\\u0628 \\u0641\\u064a \\u0627\\u0644\\u062f\\u0648\\u0631 \\u0627\\u0644\\u0627\\u0648\\u0644 \\u0644\\u062f\\u0648\\u0631\\u064a \\u0639\\u0627\\u0645 \\u0643\\u0631\\u0629 \\u0627\\u0644\\u064a\\u062f \\u0648\\u0647\\u0648 \\u064a\\u0645\\u0644\\u0643 \\u0645\\u062c\\u0645\\u0648\\u0639\\u0629 \\u0642\\u0648\\u064a\\u0629 \\u0645\\u0646 \\u0627\\u0644\\u0644\\u0627\\u0639\\u0628\\u064a\\u0646 \\u062a\\u062d\\u062a \\u0627\\u0634\\u0631\\u0627\\u0641 \\u0646\\u0628\\u064a\\u0644 \\u0627\\u0644\\u0645\\u0639\\u0634\\u0631\\u064a \\u0627\\u0644\\u0630\\u064a \\u0627\\u0643\\u062f \\u0639\\u0644\\u0649 \\u0627\\u0647\\u0645\\u064a\\u0629 \\u0627\\u0644\\u062e\\u0631\\u0648\\u062c \\u0628\\u0646\\u062a\\u064a\\u062c\\u0629 \\u0637\\u064a\\u0628\\u0629 \\u0641\\u064a .\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n 
{\n \"column\": \"label\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 7,\n \"num_unique_values\": 8,\n \"samples\": [\n 6,\n 0,\n 3\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"tags\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 12723,\n \"samples\": [\n \"0 0 0 0 0 90 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\",\n \"0 0 0 0 0 0 0 0 0 0 113 0 0 0 0 0 0 0 0 0\",\n \"0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 38 0 0 0\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 63 } ] }, { "cell_type": "code", "source": [ "\n", "# we make the tags labels as binary either 0,1 to avoid complexity and test the model\n", "dataset_tr['tags'] = dataset_tr['tags'].str.split().apply(lambda tags: ' '.join('1' if int(t) != 0 else '0' for t in tags))\n", "dataset_tr.head()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 920 }, "id": "-BjiZYU2aFCi", "outputId": "71e98843-a395-4aff-d905-2d8bf2a7f251" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " text \\\n", "0 مما يؤكل منه قال لهما ( ما أصبتما من أخيكما أع... \n", "1 وتسعى السلطنة لتشجيع وتطوير هذا القطاع ليتيح ف... \n", "2 خلال اقتراح بعض الحلول المناسبة للمشكلات والصع... \n", "3 في الحديث الصحيح مع النبي عليه وعلى آله وصحبه ... \n", "4 كيلومترات تكون بداية الانطلاقة من أمام مستشفى ... \n", "\n", " correct_text label \\\n", "0 مما يؤكل منه قال لهما ( ما أصبتما من أخيكما أع... 3 \n", "1 وتسعى السلطنة لتشجيع وتطوير هذا القطاع ليتيح ف... 3 \n", "2 خلال اقتراح بعض الحلول المناسبة للمشكلات والصع... 6 \n", "3 في الحديث الصحيح عن النبي عليه وعلى آله وصحبه ... 6 \n", "4 كيلومترات تكون بداية الانطلاقة من أمام مستشفى ... 6 \n", "\n", " tags \n", "0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 ... \n", "1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 ... 
\n", "2 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 ... \n", "3 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ... \n", "4 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ... " ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
textcorrect_textlabeltags
0مما يؤكل منه قال لهما ( ما أصبتما من أخيكما أع...مما يؤكل منه قال لهما ( ما أصبتما من أخيكما أع...30 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 ...
1وتسعى السلطنة لتشجيع وتطوير هذا القطاع ليتيح ف...وتسعى السلطنة لتشجيع وتطوير هذا القطاع ليتيح ف...30 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 ...
2خلال اقتراح بعض الحلول المناسبة للمشكلات والصع...خلال اقتراح بعض الحلول المناسبة للمشكلات والصع...60 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 ...
3في الحديث الصحيح مع النبي عليه وعلى آله وصحبه ...في الحديث الصحيح عن النبي عليه وعلى آله وصحبه ...60 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
4كيلومترات تكون بداية الانطلاقة من أمام مستشفى ...كيلومترات تكون بداية الانطلاقة من أمام مستشفى ...60 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "dataset_tr", "summary": "{\n \"name\": \"dataset_tr\",\n \"rows\": 16212,\n \"fields\": [\n {\n \"column\": \"text\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 16189,\n \"samples\": [\n \"( \\u0627\\u0644\\u0648\\u0637\\u0646 ) \\u0627\\u0644\\u062a\\u0642\\u062a \\u0628\\u0645\\u062c\\u0645\\u0648\\u0639\\u0629 \\u0645\\u0646 \\u0627\\u0644\\u0627\\u0637\\u0641\\u0627\\u0644 \\u0641\\u064a \\u0627\\u0644\\u0639\\u064a\\u0648\\u062f \\u0644\\u062a\\u0623\\u062e\\u0630 \\u0622\\u0631\\u0627\\u0621\\u0647\\u0645 \\u0627\\u0646\\u0637\\u0628\\u0627\\u0639\\u0627\\u0646 \\u0639\\u0646 \\u0627\\u0644\\u0639\\u064a\\u0648\\u062f \\u0641\\u0642\\u0627\\u0644\\u062a \\u0645\\u0631\\u064a\\u0645 \\u0628\\u0646\\u064a\\u062a\\u0645\\u0627 \\u0628\\u0634\\u064a\\u0631 \\u0627\\u0644\\u0631\\u064a\\u0627\\u0645\\u064a : \\u0627\\u0646 \\u0627\\u0644\\u0639\\u064a\\u0648\\u062f \\u0644\\u0647 \\u062a\\u0637\\u0639\\u0645\\u0627\\u0646 \\u062e\\u0627\\u0635 \\u0648\\u0646\\u062d\\u0646 \\u064a\\u062d\\u0631\\u0635 \\u0643\\u062b\\u064a\\u0631\\u0627 \\u0639\\u0644\\u0649 \\u062d\\u0636\\u0648\\u0631 .\",\n \"\\u064a\\u0627 \\u0647\\u0644\\u0627 \\u0648\\u062d\\u064a\\u0627\\u0643\\u0645 \\u0645\\u0646 \\u0627\\u062c\\u0645\\u0644 \\u0627\\u0644\\u0627\\u0645\\u0633\\u064a\\u0627\\u062a \\u0641\\u064a \\u0645\\u0647\\u0631\\u062c\\u0627\\u0646 \\u0645\\u0633\\u0642\\u0637 \\u0627\\u0644\\u0627\\u0645\\u0633\\u064a\\u0627\\u062a \\u0627\\u0644\\u0634\\u0639\\u0631\\u064a\\u0629 \\u0627\\u0644\\u062a\\u064a \\u0627\\u0633\\u062a\\u0642\\u0637\\u0628\\u062a \\u062c\\u0645\\u0627\\u0647\\u064a\\u0631 \\u063a\\u0641\\u064a\\u0631\\u0629 \\u062a\\u0641\\u0627\\u0639\\u0644\\u062a \\u0628\\u062d\\u0645\\u0627\\u0633 \\u0643\\u064a \\u0627\\u0644\\u0634\\u0639\\u0631\\u0627\\u0621 \\u0627\\u0644\\u0634\\u0639\\u0631\\u0627\\u0621 
\\u062d\\u0636\\u0631\\u0648\\u0627 \\u0628\\u0642\\u0648\\u0629 \\u0643\\u0644\\u0645\\u0627\\u062a\\u0647\\u0645 \\u0627\\u0644\\u0634\\u0639\\u0631\\u064a\\u0629 \\u0627\\u0644\\u062c\\u0645\\u064a\\u0644\\u0629 \\u0648\\u0645\\u0646 \\u062c\\u0627\\u0646\\u0628\\u0647 \\u0627\\u0644\\u062d\\u0636\\u0648\\u0631 \\u0627\\u0644\\u0645\\u062a\\u0645\\u064a\\u0632 \\u0627\\u0633\\u062a\\u0647\\u0648\\u0649 .\",\n \"2 \\u0627\\u0646\\u062a\\u0634\\u0627\\u0631 \\u0627\\u0645\\u062a\\u0644\\u0627\\u0643 \\u0627\\u0644\\u0627\\u0635\\u0648\\u0644 \\u0627\\u0644\\u0645\\u0627\\u0644\\u064a\\u0629 \\u0627\\u0644\\u063a\\u0631\\u0636 \\u0627\\u0644\\u0627\\u0633\\u0627\\u0633\\u064a \\u0645\\u0646 \\u0627\\u0644\\u062a\\u0648\\u0631\\u064a\\u0642 \\u0647\\u0648 \\u062a\\u0648\\u0632\\u064a\\u0639 \\u0627\\u0644\\u0627\\u0635\\u0648\\u0644 \\u0627\\u0644\\u0645\\u0627\\u0644\\u064a\\u0629 \\u062b\\u0645 \\u0627\\u0643\\u0628\\u0631 \\u0639\\u062f\\u062f \\u0645\\u0645\\u0643\\u0646 \\u0645\\u0646 \\u0627\\u0644\\u0645\\u062f\\u062e\\u0631\\u064a\\u0646 .\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"correct_text\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 15257,\n \"samples\": [\n \"18 \\u0627\\u0644\\u0649 25 \\u064a\\u0648\\u0644\\u064a\\u0648 \\u0627\\u0644\\u062d\\u0627\\u0644\\u064a \\u0648\\u0645\\u0634\\u0627\\u0631\\u0643\\u062a\\u0647\\u0627 \\u0627\\u064a\\u0636\\u0627 \\u0641\\u064a \\u0627\\u0644\\u0628\\u0637\\u0648\\u0644\\u062a\\u064a\\u0646 \\u0627\\u0644\\u062f\\u0648\\u0644\\u064a\\u062a\\u064a\\u0646 \\u0644\\u0644\\u0645\\u062c\\u0645\\u0648\\u0639\\u0629 \\u0627\\u0644\\u0631\\u0627\\u0628\\u0639\\u0629 \\u0628\\u062f\\u0645\\u0634\\u0642 \\u0645\\u0646 9 \\u0627\\u0644\\u0649 21 \\u0627\\u063a\\u0633\\u0637\\u0633 \\u0627\\u0644\\u0642\\u0627\\u062f\\u0645 \\u0648\\u0627\\u0644\\u0628\\u0637\\u0648\\u0644\\u0629 \\u0627\\u0644\\u062f\\u0648\\u0644\\u064a\\u0629 
\\u0644\\u0644\\u0645\\u062c\\u0645\\u0648\\u0639\\u0629 \\u0627\\u0644\\u062e\\u0627\\u0645\\u0633\\u0629 \\u0627\\u0644\\u0645\\u0642\\u0631\\u0631 \\u0627\\u0642\\u0627\\u0645\\u062a\\u0647\\u0627 \\u0641\\u064a \\u0644\\u0628\\u0646\\u0627\\u0646 \\u062e\\u0644\\u0627\\u0644 \\u0627\\u0644\\u0641\\u062a\\u0631\\u0629 \\u0645\\u0646 .\",\n \"\\u0648\\u0645\\u0639 \\u0628\\u062f\\u0627\\u064a\\u0627\\u062a \\u0627\\u0644\\u062b\\u0645\\u0627\\u0646\\u064a\\u0646\\u064a\\u0627\\u062a\\u060c \\u062a\\u0642\\u0644\\u0635 \\u0627\\u0644\\u0627\\u0647\\u062a\\u0645\\u0627\\u0645 \\u0627\\u0644\\u0631\\u0633\\u0645\\u064a \\u0641\\u064a \\u0627\\u0644\\u0635\\u064a\\u0646 \\u0628\\u0627\\u0644\\u0645\\u0648\\u0633\\u064a\\u0642\\u0649 \\u0627\\u0644\\u0634\\u0639\\u0628\\u064a\\u0629 \\u0628\\u0639\\u062f \\u0627\\u0644\\u062a\\u062d\\u0648\\u0644 \\u0625\\u0644\\u0649 \\u0633\\u064a\\u0627\\u0633\\u0629 \\u0627\\u0644\\u0627\\u0642\\u062a\\u0635\\u0627\\u062f \\u0627\\u0644\\u062d\\u0631 .\",\n \"\\u0641\\u064a \\u0627\\u0644\\u0645\\u0631\\u0643\\u0632 \\u0627\\u0644\\u062b\\u0627\\u0646\\u064a \\u0628\\u0639\\u062f \\u0633\\u062f\\u0627\\u0628 \\u0641\\u064a \\u0627\\u0644\\u062f\\u0648\\u0631 \\u0627\\u0644\\u0627\\u0648\\u0644 \\u0644\\u062f\\u0648\\u0631\\u064a \\u0639\\u0627\\u0645 \\u0643\\u0631\\u0629 \\u0627\\u0644\\u064a\\u062f \\u0648\\u0647\\u0648 \\u064a\\u0645\\u0644\\u0643 \\u0645\\u062c\\u0645\\u0648\\u0639\\u0629 \\u0642\\u0648\\u064a\\u0629 \\u0645\\u0646 \\u0627\\u0644\\u0644\\u0627\\u0639\\u0628\\u064a\\u0646 \\u062a\\u062d\\u062a \\u0627\\u0634\\u0631\\u0627\\u0641 \\u0646\\u0628\\u064a\\u0644 \\u0627\\u0644\\u0645\\u0639\\u0634\\u0631\\u064a \\u0627\\u0644\\u0630\\u064a \\u0627\\u0643\\u062f \\u0639\\u0644\\u0649 \\u0627\\u0647\\u0645\\u064a\\u0629 \\u0627\\u0644\\u062e\\u0631\\u0648\\u062c \\u0628\\u0646\\u062a\\u064a\\u062c\\u0629 \\u0637\\u064a\\u0628\\u0629 \\u0641\\u064a .\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n 
{\n \"column\": \"label\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 7,\n \"num_unique_values\": 8,\n \"samples\": [\n 6,\n 0,\n 3\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"tags\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2756,\n \"samples\": [\n \"0 0 0 0 0 0 0 0 1 1 0 0\",\n \"0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0\",\n \"0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 64 } ] }, { "cell_type": "code", "source": [], "metadata": { "id": "N4bPEakuY4Z6" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Data processing" ], "metadata": { "id": "9S5QJ0KWUTk9" } }, { "cell_type": "code", "source": [ "\n", "from sklearn.model_selection import train_test_split\n", "\n", "dataset_tr, dataset_val = train_test_split(dataset_tr, test_size=0.1, random_state=42,stratify=dataset_tr['label'])" ], "metadata": { "id": "nJW0juw8UQYQ" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Token classification data" ], "metadata": { "id": "MDjXnWFAftOz" } }, { "cell_type": "code", "source": [ "# Token Classification Data\n", "train_token = SingleDatasetProcessing(\n", " dataset_tr,\n", " tokenizer,\n", " 256,# max_length\n", " \"error_detection\",# task_name\n", " TrainingStrategy.TOKEN_CLASSIFICATION,\n", " text_column='text',\n", " tags_column='tags',\n", " token_labeling_first_only=False)\n", "val_token = SingleDatasetProcessing(dataset_val, tokenizer, 256,\"error_detection\", TrainingStrategy.TOKEN_CLASSIFICATION,text_column='text', tags_column='tags',token_labeling_first_only=False,label_to_ids=train_token.label_to_ids)\n", "error_detection_config=TaskConfig(\n", " task_name=\"error_detection\",\n", " training_strategy=TrainingStrategy.TOKEN_CLASSIFICATION,\n",
" datasets=[],\n", " num_labels=train_token.num_labels,\n", " max_length=256\n", " )" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "PfRJJWLlYHVE", "outputId": "bb89ba2e-b26a-4cde-fb95-afc6e9473d69" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " Sample text: متطلعا ثم البحر والغمام الذي يعانقه ويختلط لونه مع لونه تقف و ( الكوس ) الباردة تهب عليك ....\n", " Sample tags: [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]...\n", " Created 14590 samples for error_detection >> token_classification\n", " Sample text: الله صلى الله عليه وسلم حتى مات، ومع آبيين بكر رضي الله عنه حتى مات، ومع عمر رضي الله عنه، فنحن نغزو عنك، فأبى فجهزوه فركب البحر فمات، فلم يجدوا ....\n", " Sample tags: [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]...\n", " Created 1622 samples for error_detection >> token_classification\n" ] } ] }, { "cell_type": "markdown", "source": [ "## Multi dataset\n", "lets add different datasets with different strategies" ], "metadata": { "id": "Lk4VvBMXqvsW" } }, { "cell_type": "code", "source": [ "# lets add different data for other tasks\n", "\n", "## generation data with Question Answering\n", "from dytr import process_qa_dataset\n", "\n", "dataset_tr_qa=load_dataset('FreedomIntelligence/Code-Alpaca-Arabic-GPT4',split='train')\n", "\n", "df_qa=process_qa_dataset(dataset_tr_qa,model.config,conversations_col=\"conversations\")\n", "train_df_qa, val_df_qa = train_test_split(df_qa, test_size=0.01, random_state=42)\n", "train_causal = SingleDatasetProcessing(\n", " train_df_qa, tokenizer, 128,\n", " \"generation\", TrainingStrategy.CAUSAL_LM,\n", " text_column='text',#cache_dir=os.path.join(exp_dir,\"dataset_cache\"),\n", " )\n", "\n", "val_causal = SingleDatasetProcessing(\n", " val_df_qa, tokenizer, 128,\n", " \"generation\", TrainingStrategy.CAUSAL_LM,\n", " text_column='text',#cache_dir=os.path.join(exp_dir,\"dataset_cache\"),\n", 
" )\n", "generation_config=TaskConfig(\n", " task_name=\"generation\",\n", " training_strategy=TrainingStrategy.CAUSAL_LM,\n", " max_length=128,\n", " )\n", "\n", "## Another NER task data\n", "\n", "dataset_val_ner=load_dataset(\"iSemantics/conllpp-ner-ar\",split='test')\n", "dataset_tr_ner=load_dataset(\"iSemantics/conllpp-ner-ar\",split='train')\n", "dataset_val_ner=pd.DataFrame(dataset_val_ner)\n", "dataset_tr_ner=pd.DataFrame(dataset_tr_ner)\n", "dataset_tr_ner['ner_tags'] = dataset_tr_ner['ner_tags'].apply(lambda x: ' '.join(map(str,x)) if isinstance(x, list) else '')\n", "dataset_tr_ner['tokens'] = dataset_tr_ner['tokens'].apply(lambda x: ' '.join(map(str,x)) if isinstance(x, list) else '')\n", "dataset_val_ner['ner_tags'] = dataset_val_ner['ner_tags'].apply(lambda x: ' '.join(map(str,x)) if isinstance(x, list) else '')\n", "dataset_val_ner['tokens'] = dataset_val_ner['tokens'].apply(lambda x: ' '.join(map(str,x)) if isinstance(x, list) else '')\n", "\n", "train_token_ner = SingleDatasetProcessing(\n", " dataset_tr_ner, tokenizer, 256,\n", " \"ner_detection\", TrainingStrategy.TOKEN_CLASSIFICATION,\n", " text_column='tokens', tags_column='ner_tags',\n", " )\n", "val_token_ner = SingleDatasetProcessing(\n", " dataset_val_ner, tokenizer, 256,\n", " \"ner_detection\", TrainingStrategy.TOKEN_CLASSIFICATION,\n", " text_column='tokens', tags_column='ner_tags',label_to_ids=train_token_ner.label_to_ids\n", " )\n", "ner_detection_config=TaskConfig(\n", " task_name=\"ner_detection\",\n", " training_strategy=TrainingStrategy.TOKEN_CLASSIFICATION,\n", " datasets=[],max_length=256,\n", " num_labels=train_token_ner.num_labels,\n", " )\n", "## sentiment data, sentence classification\n", "dataset_tr_sem=load_dataset('sepidmnorozy/Arabic_sentiment',split='train')\n", "dataset_val_sem=load_dataset('sepidmnorozy/Arabic_sentiment',split='test')\n", "dataset_tr_sem=pd.DataFrame(dataset_tr_sem)\n", "dataset_val_sem =pd.DataFrame(dataset_val_sem)\n", "train_class_sm = 
SingleDatasetProcessing(\n", " dataset_tr_sem, tokenizer, 128,\n", " \"sentiment\", TrainingStrategy.SENTENCE_CLASSIFICATION,\n", " text_column='text', label_column='label'\n", " )\n", "val_class_sm = SingleDatasetProcessing(\n", " dataset_val_sem, tokenizer, 128,\n", " \"sentiment\", TrainingStrategy.SENTENCE_CLASSIFICATION,\n", " text_column='text', label_column='label')\n", "sentiment_config=TaskConfig(\n", " task_name=\"sentiment\",\n", " training_strategy=TrainingStrategy.SENTENCE_CLASSIFICATION,\n", " datasets=[],max_length=128,\n", " num_labels=train_class_sm.num_labels,\n", " )" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "p1_Ad7pfcOVu", "outputId": "4db5dd53-bff2-420a-c34c-0e3006dc3074" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " Sample processed text: Q: أعد الفهرس لأول حدوث لعنصر ما في المصفوفة المعطاة.\n", "arr = [2, 3, 3, 1, 5, 2] \n", " <|answer|> \n", " إذا كنت تبحث عن فهرس أول ظهور لعنصر معين في القائمة، يمكنك استخدام الأمر index في ال Python. \n", "\n", "مثلا، إذا كنت تبحث عن أول وقوع للرقم 3، يمكنك القيام بما يلي:\n", "\n", "index = arr.index(3)\n", "يمكنك استبدل الرقم 3 بأي رقم (أو عنصر) تبحث عنه.\n", "\n", "لذا القطعة البرمجية الكاملة ستكون كما يلي:\n", "\n", "arr = [2, 3, 3, 1, 5, 2]\n", "index = arr.index(3)\n", "print(index)\n", "\n", "عند تشغيل هذا، سيكون الناتج 1 (لأن العنصر 3 يظهر أولاً في الموقع 1 قائمه ال arr). 
\n", " \n", "...\n", "Total Questions Answers: 20017\n", " Creating causal LM dataset with dynamic window sampling...\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Processing documents: 100%|██████████| 19816/19816 [00:12<00:00, 1584.28it/s]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Do you want to save Causal Data Window in cach dir?[Y/n]\n", " Created and cached 74944 windows from 19816 documents\n", " Created 74944 samples for generation >> causal_lm\n", " Creating causal LM dataset with dynamic window sampling...\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Processing documents: 100%|██████████| 201/201 [00:00<00:00, 1613.72it/s]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Do you want to save Causal Data Window in cach dir?[Y/n]\n", " Created and cached 812 windows from 201 documents\n", " Created 812 samples for generation >> causal_lm\n", " Sample text: الاتحاد الأوروبي يرفض الدعوة الألمانية لمقاطعة لحم الضأن البريطاني ....\n", " Sample tags: [3, 4, 0, 0, 7, 0, 0, 0, 7, 0]...\n", " Created 8926 samples for ner_detection >> token_classification\n", " Sample text: كرة القدم - اليابان تحقق فوزًا محظوظًا، الصين في هزيمة مفاجئة....\n", " Sample tags: [0, 0, 0, 5, 0, 0, 0, 5, 0, 0, 0]...\n", " Created 2202 samples for ner_detection >> token_classification\n", " Sample text: ربك دايما جنبك لو نديته هتلاقيه ...\n", " Sample label: 1\n", " Created 2468 samples for sentiment >> sentence_classification\n", " Sample text: سلامة أحمد سلامة، تركت إرثاً من الأخلاق حين عزت و من الحصافة حين اهتزت و من الشجاعة حين كان لها ثمن يرحمك الله...\n", " Sample label: 1\n", " Created 706 samples for sentiment >> sentence_classification\n" ] } ] }, { "cell_type": "code", "source": [ "train_datasets = {\n", " \"sentiment\": (train_class_sm, TrainingStrategy.SENTENCE_CLASSIFICATION),\n", " \"error_detection\": (train_token, TrainingStrategy.TOKEN_CLASSIFICATION),\n", " \"ner_detection\": (train_token_ner, 
TrainingStrategy.TOKEN_CLASSIFICATION),\n", " \"generation\": (train_causal, TrainingStrategy.CAUSAL_LM)\n", " }\n", "\n", "val_datasets = {\n", " \"sentiment\": (val_class_sm, TrainingStrategy.SENTENCE_CLASSIFICATION),\n", " \"error_detection\": (val_token, TrainingStrategy.TOKEN_CLASSIFICATION),\n", " \"ner_detection\": (val_token_ner, TrainingStrategy.TOKEN_CLASSIFICATION),\n", " \"generation\": (val_causal, TrainingStrategy.CAUSAL_LM)\n", " }\n", "task_configs_list=[sentiment_config,error_detection_config,ner_detection_config, generation_config]" ], "metadata": { "id": "Lq_2SHi-jyft" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Training on multiple Tasks simultaneously" ], "metadata": { "id": "Ju_yR-i3Sdpi" } }, { "cell_type": "code", "source": [ "model.config" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "px-aR7rJljn7", "outputId": "d7666ed8-8f70-4655-ee2a-95d626e14330" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "ModelConfig(embed_dim=256, num_layers=4, num_heads=4, head_dim=64, ff_mult=4, dropout=0.1, max_seq_len=512, learning_rate=0.0003, batch_size=16, weight_decay=0.01, gradient_clip=1.0, warmup_steps=1000, max_learning_rate=0.0005, min_learning_rate=1e-06, adam_epsilon=1e-08, label_smoothing=0.1, fp16=False, gradient_accumulation_steps=1, max_grad_norm=1.0, patience=3, evaluation_strategy='steps', logging_steps=50, validation_check_interval=500, load_best_model_at_end=True, metric_for_best_model='loss', early_stopping_patience=10, max_train_steps=100000, num_train_epochs=5, lr_scheduler_type='cosine', per_device_train_batch_size=64, per_device_eval_batch_size=16, dataloader_num_workers=2, dataloader_pin_memory=True, seed=42, task_specific_lr={}, task_weights={}, use_rotary_embedding=False, use_flash_attention=False, gradient_checkpointing=False, training_from_scratch=False, special_tokens={}, window_size=256, stride=64, 
tasks={}, vocab_size=32000, tokenizer_name='asafaya/bert-mini-arabic', add_tab_newline_vocab=False, use_simple_tokenizer=True, tokenizer_type='wordpiece', bos_token_id=2, eos_token_id=3, adapter_bottleneck=64, use_task_adapters=False, ewc_lambda=1000.0, replay_buffer_size=1000, use_ewc=True, use_replay=True, causal_lm_window_size=256, causal_lm_stride=128, head_lr_mult=1.0, decoder_lr_mult=1.0, shared_lr_mult=0.1, device='cuda')" ] }, "metadata": {}, "execution_count": 12 } ] }, { "cell_type": "code", "source": [ "Train_config=model.config\n", "Train_config.evaluation_strategy='epoch'\n", "Train_config.num_train_epochs=10\n", "Train_config.head_lr_mult=1.0\n", "Train_config.decoder_lr_mult=2.0 # increased for generation which based on decoder\n", "Train_config.shared_lr_mult=0.2\n", "Train_config.learning_rate=3e-4\n", "Train_config.per_device_train_batch_size=128\n", "Train_config.per_device_eval_batch_size=32" ], "metadata": { "id": "72h0pfuxl5Bn" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "\n", "exp_dir='dytr_model'\n", "trainer = Trainer(model, Train_config, exp_dir)\n", "model=trainer.train(task_configs_list, train_datasets, val_datasets)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "y7EDQXU3SYVc", "outputId": "16dabdfe-29f8-45d0-fe39-79995a2f8180" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Trainer initialized\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - ============================================================\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Starting training session\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - ============================================================\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Frozen all model parameters\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Unfrozen shared 
encoder and embeddings\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Unfrozen head for task: sentiment\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Unfrozen head for task: error_detection\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Unfrozen head for task: ner_detection\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Unfrozen decoder for task: generation\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Model parameters after adding tasks:\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Model size: 25009293 parameters\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Trainable: 25009293\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - \n", "Parameter breakdown:\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Head sentiment: 33410\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Head error_detection: 66818\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Head ner_detection: 68617\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Decoder generation: 20628224\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Shared encoder: 12,404,224\n", "2026-04-02 14:52:47 - dytr.training.trainer - INFO - Shared embedding: 8,192,000\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Max lengths per task: {'sentiment': 128, 'error_detection': 256, 'ner_detection': 256, 'generation': 128}\n", "Max lengths per task: {'sentiment': 128, 'error_detection': 256, 'ner_detection': 256, 'generation': 128}\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Epoch 1/10: 100%|██████████| 786/786 [05:24<00:00, 2.42it/s, loss=4.5472, tasks={'generation': 4.1654456424713135, 'error_detection': 0.35143900513648985, 'ner_detection': 1.3261612837805468, 'sentiment': 0.6868754344827989}]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 1 average training loss: 4.4872\n" ] }, { "output_type": "stream", "name": "stderr", 
"text": [ "2026-04-02 14:58:14 - dytr.training.trainer - WARNING - ✓ Best model saved improvement: inf\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Validation loss: 1.2110\n", " sentiment validation:\n", " accuracy: 0.7898\n", " f1_macro: 0.7807\n", " f1_weighted: 0.7868\n", " f1_micro: 0.7898\n", " error_detection validation:\n", " token_accuracy: 0.9374\n", " f1_macro: 0.5004\n", " f1_weighted: 0.9092\n", " f1_micro: 0.9374\n", " ner_detection validation:\n", " token_accuracy: 0.7956\n", " f1_macro: 0.2122\n", " f1_weighted: 0.7308\n", " f1_micro: 0.7956\n", " generation validation:\n", " perplexity: 58.8416\n", " token_accuracy: 0.4213\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Epoch 2/10: 100%|██████████| 786/786 [05:21<00:00, 2.44it/s, loss=2.8689, tasks={'generation': 3.456082937717438, 'error_detection': 0.3213178336620331, 'ner_detection': 0.9571117907762527, 'sentiment': 0.46282851696014404}]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 2 average training loss: 2.8627\n", "\n", " Validation loss: 1.0950\n", " sentiment validation:\n", " accuracy: 0.8026\n", " f1_macro: 0.7968\n", " f1_weighted: 0.8029\n", " f1_micro: 0.8026\n", " error_detection validation:\n", " token_accuracy: 0.9397\n", " f1_macro: 0.5589\n", " f1_weighted: 0.9173\n", " f1_micro: 0.9397\n", " ner_detection validation:\n", " token_accuracy: 0.8140\n", " f1_macro: 0.3069\n", " f1_weighted: 0.7764\n", " f1_micro: 0.8140\n", " generation validation:\n", " perplexity: 34.2676\n", " token_accuracy: 0.5246\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 15:03:39 - dytr.training.trainer - WARNING - ✓ Best model saved improvement: 0.11600485552440998\n", "Epoch 3/10: 100%|██████████| 786/786 [05:22<00:00, 2.44it/s, loss=2.5558, tasks={'generation': 3.182457141876221, 'error_detection': 0.30829586982727053, 'ner_detection': 0.8627469592234668, 'sentiment': 0.24829364261206457}]\n" ] }, { 
"output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 3 average training loss: 2.5546\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 15:09:05 - dytr.training.trainer - WARNING - ✓ Best model saved improvement: 0.019131840900941333\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Validation loss: 1.0759\n", " sentiment validation:\n", " accuracy: 0.7898\n", " f1_macro: 0.7839\n", " f1_weighted: 0.7902\n", " f1_micro: 0.7898\n", " error_detection validation:\n", " token_accuracy: 0.9386\n", " f1_macro: 0.5830\n", " f1_weighted: 0.9197\n", " f1_micro: 0.9386\n", " ner_detection validation:\n", " token_accuracy: 0.8175\n", " f1_macro: 0.3223\n", " f1_weighted: 0.7831\n", " f1_micro: 0.8175\n", " generation validation:\n", " perplexity: 28.9015\n", " token_accuracy: 0.5590\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Epoch 4/10: 100%|██████████| 786/786 [05:21<00:00, 2.44it/s, loss=2.4043, tasks={'generation': 3.0265496969223022, 'error_detection': 0.29168049663305284, 'ner_detection': 0.785723519675872, 'sentiment': 0.21430100851199207}]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 4 average training loss: 2.4046\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 15:14:30 - dytr.training.trainer - WARNING - No improvement for 1 epochs\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Validation loss: 1.0782\n", " sentiment validation:\n", " accuracy: 0.7812\n", " f1_macro: 0.7764\n", " f1_weighted: 0.7821\n", " f1_micro: 0.7812\n", " error_detection validation:\n", " token_accuracy: 0.9374\n", " f1_macro: 0.5802\n", " f1_weighted: 0.9189\n", " f1_micro: 0.9374\n", " ner_detection validation:\n", " token_accuracy: 0.8122\n", " f1_macro: 0.3443\n", " f1_weighted: 0.7869\n", " f1_micro: 0.8122\n", " generation validation:\n", " perplexity: 26.8652\n", " token_accuracy: 0.5777\n" ] }, { "output_type": "stream", 
"name": "stderr", "text": [ "Epoch 5/10: 100%|██████████| 786/786 [05:22<00:00, 2.44it/s, loss=2.3015, tasks={'generation': 2.913034658432007, 'error_detection': 0.2756128540635109, 'ner_detection': 0.7202174944036147, 'sentiment': 0.20325492760714362}]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 5 average training loss: 2.3023\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 15:19:55 - dytr.training.trainer - WARNING - No improvement for 2 epochs\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Validation loss: 1.0856\n", " sentiment validation:\n", " accuracy: 0.7898\n", " f1_macro: 0.7842\n", " f1_weighted: 0.7902\n", " f1_micro: 0.7898\n", " error_detection validation:\n", " token_accuracy: 0.9345\n", " f1_macro: 0.5928\n", " f1_weighted: 0.9190\n", " f1_micro: 0.9345\n", " ner_detection validation:\n", " token_accuracy: 0.8015\n", " f1_macro: 0.3492\n", " f1_weighted: 0.7830\n", " f1_micro: 0.8015\n", " generation validation:\n", " perplexity: 25.8716\n", " token_accuracy: 0.5886\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Epoch 6/10: 100%|██████████| 786/786 [05:22<00:00, 2.44it/s, loss=2.2210, tasks={'generation': 2.8193092226982115, 'error_detection': 0.26240902438759806, 'ner_detection': 0.6710079382447636, 'sentiment': 0.2012844962232253}]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 6 average training loss: 2.2222\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 15:25:20 - dytr.training.trainer - WARNING - No improvement for 3 epochs\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " Validation loss: 1.0856\n", " sentiment validation:\n", " accuracy: 0.7784\n", " f1_macro: 0.7723\n", " f1_weighted: 0.7788\n", " f1_micro: 0.7784\n", " error_detection validation:\n", " token_accuracy: 0.9335\n", " f1_macro: 0.5934\n", " f1_weighted: 0.9187\n", " f1_micro: 0.9335\n", " 
ner_detection validation:\n", " token_accuracy: 0.8059\n", " f1_macro: 0.3617\n", " f1_weighted: 0.7852\n", " f1_micro: 0.8059\n", " generation validation:\n", " perplexity: 25.2586\n", " token_accuracy: 0.5972\n", " Early stopping triggered after 6 Validation\n", "Insert letter Y to stop the training: [Y/n]\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Epoch 7/10: 100%|██████████| 786/786 [05:22<00:00, 2.44it/s, loss=2.1543, tasks={'generation': 2.7420612215995788, 'error_detection': 0.2517592230439186, 'ner_detection': 0.6321818495497984, 'sentiment': 0.2011947701959049}]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 7 average training loss: 2.1558\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 15:31:38 - dytr.training.trainer - WARNING - No improvement for 1 epochs\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Validation loss: 1.0942\n", " sentiment validation:\n", " accuracy: 0.7812\n", " f1_macro: 0.7755\n", " f1_weighted: 0.7817\n", " f1_micro: 0.7812\n", " error_detection validation:\n", " token_accuracy: 0.9320\n", " f1_macro: 0.5905\n", " f1_weighted: 0.9175\n", " f1_micro: 0.9320\n", " ner_detection validation:\n", " token_accuracy: 0.8016\n", " f1_macro: 0.3528\n", " f1_weighted: 0.7841\n", " f1_micro: 0.8016\n", " generation validation:\n", " perplexity: 24.9701\n", " token_accuracy: 0.6015\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Epoch 8/10: 100%|██████████| 786/786 [05:21<00:00, 2.44it/s, loss=2.1009, tasks={'generation': 2.6821344566345213, 'error_detection': 0.24392416402697564, 'ner_detection': 0.6065885240540785, 'sentiment': 0.2000615912325242}]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 8 average training loss: 2.1027\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 15:37:03 - dytr.training.trainer - WARNING - No improvement for 2 epochs\n" ] }, { "output_type": 
"stream", "name": "stdout", "text": [ "\n", " Validation loss: 1.0993\n", " sentiment validation:\n", " accuracy: 0.7841\n", " f1_macro: 0.7784\n", " f1_weighted: 0.7846\n", " f1_micro: 0.7841\n", " error_detection validation:\n", " token_accuracy: 0.9304\n", " f1_macro: 0.5944\n", " f1_weighted: 0.9173\n", " f1_micro: 0.9304\n", " ner_detection validation:\n", " token_accuracy: 0.7957\n", " f1_macro: 0.3563\n", " f1_weighted: 0.7835\n", " f1_micro: 0.7957\n", " generation validation:\n", " perplexity: 24.6795\n", " token_accuracy: 0.6058\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Epoch 9/10: 100%|██████████| 786/786 [05:22<00:00, 2.44it/s, loss=2.0636, tasks={'generation': 2.64249169588089, 'error_detection': 0.23969999939203263, 'ner_detection': 0.5925140827894211, 'sentiment': 0.20039922437247107}]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 9 average training loss: 2.0658\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 15:42:28 - dytr.training.trainer - WARNING - No improvement for 3 epochs\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " Validation loss: 1.0990\n", " sentiment validation:\n", " accuracy: 0.7855\n", " f1_macro: 0.7798\n", " f1_weighted: 0.7860\n", " f1_micro: 0.7855\n", " error_detection validation:\n", " token_accuracy: 0.9302\n", " f1_macro: 0.5974\n", " f1_weighted: 0.9175\n", " f1_micro: 0.9302\n", " ner_detection validation:\n", " token_accuracy: 0.8000\n", " f1_macro: 0.3578\n", " f1_weighted: 0.7853\n", " f1_micro: 0.8000\n", " generation validation:\n", " perplexity: 24.6720\n", " token_accuracy: 0.6076\n", " Early stopping triggered after 9 Validation\n", "Insert letter Y to stop the training: [Y/n]\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Epoch 10/10: 100%|██████████| 786/786 [05:22<00:00, 2.43it/s, loss=2.0433, tasks={'generation': 2.624660816192627, 'error_detection': 0.23840540289878845, 'ner_detection': 
0.585954545175328, 'sentiment': 0.2000701935852275}]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 10 average training loss: 2.0459\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 15:48:13 - dytr.training.trainer - WARNING - No improvement for 1 epochs\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Validation loss: 1.0993\n", " sentiment validation:\n", " accuracy: 0.7855\n", " f1_macro: 0.7798\n", " f1_weighted: 0.7860\n", " f1_micro: 0.7855\n", " error_detection validation:\n", " token_accuracy: 0.9304\n", " f1_macro: 0.5969\n", " f1_weighted: 0.9176\n", " f1_micro: 0.9304\n", " ner_detection validation:\n", " token_accuracy: 0.7994\n", " f1_macro: 0.3567\n", " f1_weighted: 0.7850\n", " f1_micro: 0.7994\n", " generation validation:\n", " perplexity: 24.6604\n", " token_accuracy: 0.6083\n" ] } ] }, { "cell_type": "code", "source": [ "\n", "#model_name='asafaya/bert-mini-arabic'\n", "model.save_model(\"finetune_bert_mini_arabic_mltitasks_and_generation.pt\")" ], "metadata": { "id": "b3Tj3buY4n3m" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "model.eval()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FS_jj7z26M07", "outputId": "337d14a3-2ac6-4c77-aac4-7e5312299878" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "DynamicTransformer(\n", " (shared_embedding): Embedding(32000, 256, padding_idx=0)\n", " (encoder): TransformerEncoder(\n", " (embedding): Embedding(32000, 256, padding_idx=0)\n", " (layers): ModuleList(\n", " (0-3): 4 x EncoderLayer(\n", " (attention): MultiHeadAttention(\n", " (q_proj): Linear(in_features=256, out_features=256, bias=True)\n", " (k_proj): Linear(in_features=256, out_features=256, bias=True)\n", " (v_proj): Linear(in_features=256, out_features=256, bias=True)\n", " (out_proj): Linear(in_features=256, out_features=256, bias=True)\n", " 
(dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (attention_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (ffn_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (ffn): FeedForward(\n", " (gate_proj): Linear(in_features=256, out_features=1024, bias=True)\n", " (up_proj): Linear(in_features=256, out_features=1024, bias=True)\n", " (down_proj): Linear(in_features=1024, out_features=256, bias=True)\n", " (activation): GELU(approximate='none')\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (decoders): ModuleDict(\n", " (generation): TransformerDecoder(\n", " (layers): ModuleList(\n", " (0-3): 4 x DecoderLayer(\n", " (self_attention): MultiHeadAttention(\n", " (q_proj): Linear(in_features=256, out_features=256, bias=True)\n", " (k_proj): Linear(in_features=256, out_features=256, bias=True)\n", " (v_proj): Linear(in_features=256, out_features=256, bias=True)\n", " (out_proj): Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (ffn): FeedForward(\n", " (gate_proj): Linear(in_features=256, out_features=1024, bias=True)\n", " (up_proj): Linear(in_features=256, out_features=1024, bias=True)\n", " (down_proj): Linear(in_features=1024, out_features=256, bias=True)\n", " (activation): GELU(approximate='none')\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", " (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (output_proj): Linear(in_features=256, out_features=32000, bias=True)\n", " (embedding): Embedding(32000, 256, padding_idx=0)\n", " )\n", " )\n", " (task_heads): ModuleDict(\n", " (sentiment): Sequential(\n", " (0): Linear(in_features=256, out_features=128, bias=True)\n", " (1): 
LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (2): GELU(approximate='none')\n", " (3): Dropout(p=0.1, inplace=False)\n", " (4): Linear(in_features=128, out_features=2, bias=True)\n", " )\n", " (error_detection): Sequential(\n", " (0): Linear(in_features=256, out_features=256, bias=True)\n", " (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (2): GELU(approximate='none')\n", " (3): Dropout(p=0.1, inplace=False)\n", " (4): Linear(in_features=256, out_features=2, bias=True)\n", " )\n", " (ner_detection): Sequential(\n", " (0): Linear(in_features=256, out_features=256, bias=True)\n", " (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (2): GELU(approximate='none')\n", " (3): Dropout(p=0.1, inplace=False)\n", " (4): Linear(in_features=256, out_features=9, bias=True)\n", " )\n", " )\n", ")" ] }, "metadata": {}, "execution_count": 16 } ] }, { "cell_type": "markdown", "source": [ "# model inference" ], "metadata": { "id": "4252SVFH6SAH" } }, { "cell_type": "code", "source": [ "\n", "\n", "\n", "inp=\"q: اشرح ما هو الخوارزمية غير القاطعة. \\n <|answer|> \\n\"\n", "model.generate(inp,task_name =\"generation\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 122 }, "id": "Jhg0xvU26YTu", "outputId": "d627ec1f-0ef3-4895-d27c-fd6a00693801" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'q : اشرح ما هو الخوارزمية غير القاطعة . < | a n s w e r | > عدد ال و ل في القا م ة ( ط و ل ) في البرمجة هي 0 ، 1 ) حيث كل عنصر في القا م ة يبد من الصفر حتى تصل ل ى البداية . 
\" 2 .'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 44 } ] }, { "cell_type": "code", "source": [ "inp='الله صلى الله عليه وسلم حتى مات، ومع آبيين بكر رضي الله عنه حتى مات، ومع عمر رضي الله عنه، فنحن نغزو عنك، فأبى فجهزوه فركب البحر فمات، فلم يجدوا .'\n", "print(\"error_detection \")\n", "model.generate(inp,task_name =\"error_detection\")[\"pairs\"]" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "2De1uWIj_1Jj", "outputId": "fd4b183e-62fb-42c9-fbfe-9e5fc46e5616" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "error_detection \n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "[('الله', 0),\n", " ('صلى', 0),\n", " ('الله', 0),\n", " ('عليه', 0),\n", " ('وسلم', 0),\n", " ('حتى', 0),\n", " ('مات', 0),\n", " ('،', 0),\n", " ('ومع', 0),\n", " ('[UNK]', 0),\n", " ('ب', 0),\n", " ('ي', 1),\n", " ('ي', 1),\n", " ('ن', 1),\n", " ('بكر', 0),\n", " ('رضي', 0),\n", " ('الله', 0),\n", " ('عنه', 0),\n", " ('حتى', 0),\n", " ('مات', 0),\n", " ('،', 0),\n", " ('ومع', 0),\n", " ('عمر', 0),\n", " ('رضي', 0),\n", " ('الله', 0),\n", " ('عنه', 0),\n", " ('،', 0),\n", " ('فنحن', 0),\n", " ('نغ', 0),\n", " ('##زو', 0),\n", " ('عنك', 0),\n", " ('،', 0),\n", " ('ف', 0),\n", " ('[UNK]', 0),\n", " ('ب', 0),\n", " ('ى', 0),\n", " ('فج', 0),\n", " ('##هز', 0),\n", " ('##وه', 0),\n", " ('فرك', 0),\n", " ('##ب', 0),\n", " ('البحر', 0),\n", " ('فمات', 0),\n", " ('،', 0),\n", " ('فلم', 0),\n", " ('يجدوا', 0),\n", " ('.', 0)]" ] }, "metadata": {}, "execution_count": 48 } ] }, { "cell_type": "code", "source": [ "\n", "inp= \" قامة شركة ستارلنك بعمل حملة توعوية بقيادة ايلون ماسك\"\n", "model.generate(inp,task_name =\"ner_detection\")[\"pairs\"]" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "S9Q3PiGxAps-", "outputId": "4d6b04f3-b00f-489c-9c94-6175d6d773db" }, "execution_count": null, "outputs": [ { 
"output_type": "execute_result", "data": { "text/plain": [ "[('قام', 0),\n", " ('##ة', 0),\n", " ('شركة', 0),\n", " ('ستار', 3),\n", " ('##لن', 4),\n", " ('##ك', 0),\n", " ('بعمل', 0),\n", " ('حملة', 0),\n", " ('تو', 0),\n", " ('##عو', 0),\n", " ('##ية', 0),\n", " ('بقيادة', 0),\n", " ('ايل', 0),\n", " ('##ون', 0),\n", " ('ماسك', 0)]" ] }, "metadata": {}, "execution_count": 52 } ] }, { "cell_type": "code", "source": [ "model.generate(inp,task_name =\"sentiment\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Zfv7JQlCBcar", "outputId": "e38dc4ad-4bc6-4843-8dc3-990660e17851" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "{'prediction': 0,\n", " 'probabilities': [0.9577341675758362, 0.04226585477590561],\n", " 'logits': [1.5720747709274292, -1.548515796661377]}" ] }, "metadata": {}, "execution_count": 53 } ] }, { "cell_type": "markdown", "source": [ "# Model Export" ], "metadata": { "id": "pM-0aEgwBvC0" } }, { "cell_type": "code", "source": [ "exporter = model.get_exporter()\n", "sentiment_model = exporter.export_single_task(\"sentiment\", \"./sentiment_model.pt\")" ], "metadata": { "id": "bSuMrTSfBtel" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "sentiment_model" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4sUQ6LFpCfP2", "outputId": "58f82b13-f58d-433c-fe31-af88e76df76d" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "SingleTaskModel(\n", " (encoder): TransformerEncoder(\n", " (embedding): Embedding(32000, 256, padding_idx=0)\n", " (layers): ModuleList(\n", " (0-3): 4 x EncoderLayer(\n", " (attention): MultiHeadAttention(\n", " (q_proj): Linear(in_features=256, out_features=256, bias=True)\n", " (k_proj): Linear(in_features=256, out_features=256, bias=True)\n", " (v_proj): Linear(in_features=256, out_features=256, bias=True)\n", " (out_proj): 
Linear(in_features=256, out_features=256, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (attention_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (ffn_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " (ffn): FeedForward(\n", " (gate_proj): Linear(in_features=256, out_features=1024, bias=True)\n", " (up_proj): Linear(in_features=256, out_features=1024, bias=True)\n", " (down_proj): Linear(in_features=1024, out_features=256, bias=True)\n", " (activation): GELU(approximate='none')\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (task_head): Sequential(\n", " (0): Linear(in_features=256, out_features=128, bias=True)\n", " (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (2): GELU(approximate='none')\n", " (3): Dropout(p=0.1, inplace=False)\n", " (4): Linear(in_features=128, out_features=2, bias=True)\n", " )\n", ")" ] }, "metadata": {}, "execution_count": 55 } ] }, { "cell_type": "code", "source": [], "metadata": { "id": "sW-Nay8WDw-E" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print(f\" Total model parameters: {sum(p.numel() for p in model.parameters()):,}\")\n", "print(f\" Total parameters of sentiment_model: {sum(p.numel() for p in sentiment_model.parameters()):,}\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DD5aQx5FDAno", "outputId": "cbaaeb0c-d505-4189-fd0d-ce2236583dd6" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " Total model parameters: 25,009,293\n", " Total parameters of sentiment_model: 12,437,634\n" ] } ] }, { "cell_type": "markdown", "source": [ "# Continue training\n", "Here we fine-tune the model on generation data, or you can add new data with a new task configuration. The validation set covers three different tasks, including generation." ], "metadata": {
"id": "0wqIsVo0DxyY" } }, { "cell_type": "code", "source": [ "\n", "model=DynamicTransformer.load_model(\"finetune_bert_mini_arabic_mltitasks_and_generation.pt\")\n", "\n", "train_datasets = {\n", "\n", " \"generation\": (train_causal, TrainingStrategy.CAUSAL_LM)\n", " }\n", "\n", "val_datasets = {\n", "\n", " \"error_detection\": (val_token, TrainingStrategy.TOKEN_CLASSIFICATION),\n", " \"ner_detection\": (val_token_ner, TrainingStrategy.TOKEN_CLASSIFICATION),\n", " \"generation\": (val_causal, TrainingStrategy.CAUSAL_LM)\n", " }\n", "task_configs_list=[sentiment_config,error_detection_config,ner_detection_config, generation_config]\n", "#task_configs_list=[ generation_config]\n", "Train_config=model.config\n", "Train_config.per_device_train_batch_size=64\n", "Train_config.num_train_epochs=5\n", "Train_config.head_lr_mult=0.05\n", "Train_config.decoder_lr_mult=1.0 # increased for generation which based on decoder\n", "Train_config.shared_lr_mult=0.05\n", "Train_config.learning_rate=3e-4\n", "exp_dir='dytr_model_continue_training'\n", "model.best_val_loss=10\n", "trainer = Trainer(model, Train_config, exp_dir)\n", "model=trainer.train(task_configs_list, train_datasets, val_datasets)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "2tKJSaRFD-rA", "outputId": "d7b11330-48ad-4195-c934-e68dcbacd0d0" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "File already exists: downloads/bert_mini_arabic/vocab.txt\n", "************************************************************\n", "File already exists: downloads/bert_mini_arabic/special_tokens_map.json\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 17:11:05 - dytr.training.trainer - INFO - Trainer initialized\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - ============================================================\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - Starting training session\n", "2026-04-02 
17:11:05 - dytr.training.trainer - INFO - ============================================================\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - Frozen all model parameters\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - Unfrozen shared encoder and embeddings\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - Unfrozen decoder for task: generation\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - Model parameters after adding tasks:\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - Model size: 25009293 parameters\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - Trainable: 24840448\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - \n", "Parameter breakdown:\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - Head sentiment: 33410\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - Head error_detection: 66818\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - Head ner_detection: 68617\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - Decoder generation: 20628224\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - Shared encoder: 12,404,224\n", "2026-04-02 17:11:05 - dytr.training.trainer - INFO - Shared embedding: 8,192,000\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Max lengths per task: {'sentiment': 128, 'error_detection': 256, 'ner_detection': 256, 'generation': 128}\n", "Max lengths per task: {'sentiment': 128, 'error_detection': 256, 'ner_detection': 256, 'generation': 128}\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Epoch 1/5: 100%|██████████| 1171/1171 [05:00<00:00, 3.89it/s, loss=2.6929, tasks={'generation': 2.7645734977722167}]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 1 average training loss: 2.6947\n", "\n", " Validation loss: 1.1850\n", " error_detection validation:\n", " token_accuracy: 0.9309\n", " f1_macro: 0.5951\n", " f1_weighted: 0.9176\n", " f1_micro: 0.9309\n", " ner_detection 
validation:\n", " token_accuracy: 0.7995\n", " f1_macro: 0.3581\n", " f1_weighted: 0.7847\n", " f1_micro: 0.7995\n", " generation validation:\n", " perplexity: 25.8841\n", " token_accuracy: 0.5956\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 17:16:09 - dytr.training.trainer - WARNING - ✓ Best model saved improvement: 8.815029585903341\n", "Epoch 2/5: 100%|██████████| 1171/1171 [04:59<00:00, 3.91it/s, loss=2.7364, tasks={'generation': 2.703877909183502}]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 2 average training loss: 2.7362\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 17:21:12 - dytr.training.trainer - WARNING - ✓ Best model saved improvement: 0.0016153450612421683\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Validation loss: 1.1834\n", " error_detection validation:\n", " token_accuracy: 0.9308\n", " f1_macro: 0.5895\n", " f1_weighted: 0.9169\n", " f1_micro: 0.9308\n", " ner_detection validation:\n", " token_accuracy: 0.8002\n", " f1_macro: 0.3601\n", " f1_weighted: 0.7845\n", " f1_micro: 0.8002\n", " generation validation:\n", " perplexity: 25.8402\n", " token_accuracy: 0.5994\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Epoch 3/5: 100%|██████████| 1171/1171 [04:59<00:00, 3.90it/s, loss=2.6583, tasks={'generation': 2.6166228294372558}]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 3 average training loss: 2.6578\n", "\n", " Validation loss: 1.1817\n", " error_detection validation:\n", " token_accuracy: 0.9311\n", " f1_macro: 0.5869\n", " f1_weighted: 0.9167\n", " f1_micro: 0.9311\n", " ner_detection validation:\n", " token_accuracy: 0.7993\n", " f1_macro: 0.3581\n", " f1_weighted: 0.7830\n", " f1_micro: 0.7993\n", " generation validation:\n", " perplexity: 25.7605\n", " token_accuracy: 0.6040\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 17:26:15 - 
dytr.training.trainer - WARNING - ✓ Best model saved improvement: 0.0016998075521907907\n", "Epoch 4/5: 100%|██████████| 1171/1171 [05:00<00:00, 3.90it/s, loss=2.5724, tasks={'generation': 2.53860111951828}]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 4 average training loss: 2.5720\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 17:31:18 - dytr.training.trainer - WARNING - ✓ Best model saved improvement: 0.0018719680659422533\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Validation loss: 1.1798\n", " error_detection validation:\n", " token_accuracy: 0.9311\n", " f1_macro: 0.5853\n", " f1_weighted: 0.9165\n", " f1_micro: 0.9311\n", " ner_detection validation:\n", " token_accuracy: 0.8000\n", " f1_macro: 0.3563\n", " f1_weighted: 0.7834\n", " f1_micro: 0.8000\n", " generation validation:\n", " perplexity: 25.5279\n", " token_accuracy: 0.6105\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "Epoch 5/5: 100%|██████████| 1171/1171 [04:59<00:00, 3.90it/s, loss=2.5115, tasks={'generation': 2.4971592569351198}]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Epoch 5 average training loss: 2.5116\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "2026-04-02 17:36:20 - dytr.training.trainer - WARNING - ✓ Best model saved improvement: 0.0001728649322803033\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n", " Validation loss: 1.1796\n", " error_detection validation:\n", " token_accuracy: 0.9312\n", " f1_macro: 0.5851\n", " f1_weighted: 0.9165\n", " f1_micro: 0.9312\n", " ner_detection validation:\n", " token_accuracy: 0.8003\n", " f1_macro: 0.3572\n", " f1_weighted: 0.7838\n", " f1_micro: 0.8003\n", " generation validation:\n", " perplexity: 25.5176\n", " token_accuracy: 0.6103\n" ] } ] } ] }