{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "77b9c4dc-4975-4e4f-b611-3a6d49378a53", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/capstor/scratch/cscs/mkariuki\n" ] } ], "source": [ "%cd /capstor/scratch/cscs/mkariuki" ] }, { "cell_type": "code", "execution_count": 2, "id": "1ae27861-481a-4ce0-92b7-36b43d1ef103", "metadata": {}, "outputs": [], "source": [ "# importing the libraries\n", "from datasets import load_dataset\n", "import matplotlib.pyplot as plt\n", "from datasets import load_from_disk\n", "import numpy as np\n", "import torch\n", "from transformers import WhisperTokenizer\n", "from transformers import WhisperFeatureExtractor\n", "from transformers import WhisperForConditionalGeneration\n", "import os\n", "import shutil\n", "from IPython.display import Audio\n", "from datasets import Audio\n", "from huggingface_hub import list_repo_files\n", "from datasets import Audio\n", "from transformers import WhisperProcessor\n", "from transformers.models.whisper.english_normalizer import BasicTextNormalizer\n", "from dataclasses import dataclass\n", "from typing import Any, Dict, List, Union\n", "import evaluate\n", "from transformers import Seq2SeqTrainingArguments\n", "from transformers import TrainerCallback\n", "from transformers.trainer_pt_utils import IterableDatasetShard\n", "from torch.utils.data import IterableDataset\n", "from transformers import Seq2SeqTrainer" ] }, { "cell_type": "code", "execution_count": 3, "id": "354328b0-db56-4ed0-b82f-e6783721e489", "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ[\"NCCL_P2P_DISABLE\"] = \"1\"\n", "os.environ[\"NCCL_IB_DISABLE\"] = \"1\"" ] }, { "cell_type": "code", "execution_count": null, "id": "b7416269-d6bd-4277-acb2-bb0a0c3f6d23", "metadata": {}, "outputs": [], "source": [ "# downloading a subset of the train dataset\n", "train_files = [f\"train/scripted/audios/train_scripted_{str(i).zfill(3)}.parquet\" for i in range(15, 
30)]\n", "full_train = load_dataset(\"Anv-ke/kikuyu\", data_files=train_files, split=\"train\")" ] }, { "cell_type": "code", "execution_count": 7, "id": "876e9e1e-c953-4e86-a765-72233db3ea42", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "28465" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(full_train)" ] }, { "cell_type": "code", "execution_count": 10, "id": "19887962-5407-4f30-8bb1-1b09ee762505", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f3f8d55f2ed94b78b8941518a484f3e6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "dev/scripted/audios/dev_scripted_002.par(…): 0%| | 0.00/757M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6bf5899742274b2fa97660fff1380da5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "dev/scripted/audios/dev_scripted_003.par(…): 0%| | 0.00/313M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7ca10c66838a465d89aa4f99b55aa262", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating train split: 0 examples [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#downloading validation subset data\n", "val_files = [f\"dev/scripted/audios/dev_scripted_{str(i).zfill(3)}.parquet\" for i in range(2, 4)]\n", "full_val = load_dataset(\"Anv-ke/kikuyu\", data_files=val_files, split=\"train\")" ] }, { "cell_type": "code", "execution_count": 11, "id": "19129b1a-dc30-4f2a-99a2-54bc4e143f55", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2681" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(full_val)" ] }, { "cell_type": "code", "execution_count": 12, "id": "e0733568-d7c8-449e-ba46-2c5a221efc0b", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "29c6aa1d798a4a79a1eaee25bd54423b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "dev_test/scripted/audios/dev_test_script(…): 0%| | 0.00/747M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "90ab6da95e304d66a1971d3c0d81643d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "dev_test/scripted/audios/dev_test_script(…): 0%| | 0.00/262M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "be7b26b3f5264719a388b368ee1b8b9f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating train split: 0 examples [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#downloading test subset data\n", "test_files = [f\"dev_test/scripted/audios/dev_test_scripted_{str(i).zfill(3)}.parquet\" for i in range(2, 4)]\n", "full_test = load_dataset(\"Anv-ke/kikuyu\", data_files=test_files, split=\"train\")" ] }, { "cell_type": "code", "execution_count": 13, "id": "1554cac8-a29c-4df1-988d-deba9a4e21d3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2766" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(full_test)" ] }, { "cell_type": "code", "execution_count": 19, "id": "30cf1d4e-5783-4eb6-a2dd-6847a60fd036", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saving Train...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f9d71d026b594d38951a6ef2aeafb218", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Saving the dataset (0/23 shards): 0%| | 0/21166 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Saving Validation...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bbc5df074a7145dd9002ee204856ed87", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Saving the dataset (0/4 shards): 0%| | 0/4143 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Saving Test...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d1d7407bf3d74ee5aef9cbd5690d2492", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Saving the dataset (0/4 shards): 0%| | 0/3756 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Done! 
All data is stored in the 'kikuyu_dataset' folder.\n" ] } ], "source": [ "base_dir = \"kikuyu_dataset\"\n", "\n", "# Create the main directory and sub-directories if they don't exist\n", "os.makedirs(os.path.join(base_dir, \"train\"), exist_ok=True)\n", "os.makedirs(os.path.join(base_dir, \"validation\"), exist_ok=True)\n", "os.makedirs(os.path.join(base_dir, \"test\"), exist_ok=True)\n", "\n", "# Now save the datasets you downloaded\n", "print(\"Saving Train...\")\n", "full_train.save_to_disk(os.path.join(base_dir, \"train\"))\n", "print(\"Saving Validation...\")\n", "full_val.save_to_disk(os.path.join(base_dir, \"validation\"))\n", "print(\"Saving Test...\")\n", "full_test.save_to_disk(os.path.join(base_dir, \"test\"))\n", "\n", "print(\"Done! All data is stored in the 'kikuyu_dataset' folder.\")" ] }, { "cell_type": "code", "execution_count": 20, "id": "9492a221-613b-40b0-8490-58b89f552aea", "metadata": {}, "outputs": [], "source": [ "# data analysis\n", "\n", "def check_corrupted_files(dataset_name, dataset_path):\n", " print(f\"\\n--- Checking {dataset_name} split ---\")\n", " ds = load_from_disk(dataset_path)\n", " \n", " corrupted_indices = []\n", " \n", " # We use a simple loop. For very large datasets, \n", " # you'd use .map(), but a loop is clearer for debugging.\n", " for i in range(len(ds)):\n", " try:\n", " # Try to access the audio array\n", " audio_array = ds[i][\"audio\"][\"array\"]\n", " \n", " # Check if the audio is empty (silence/zero length)\n", " if audio_array is None or len(audio_array) == 0:\n", " print(f\"Index {i}: Empty audio data.\")\n", " corrupted_indices.append(i)\n", " \n", " except Exception as e:\n", " print(f\"Index {i}: Failed to load. Error: {e}\")\n", " corrupted_indices.append(i)\n", " \n", " if not corrupted_indices:\n", " print(f\"✅ Success! 
No corrupted files found in {dataset_name}.\")\n", " else:\n", " print(f\"⚠️ Found {len(corrupted_indices)} issues in {dataset_name}.\")\n", " \n", " return corrupted_indices\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 21, "id": "1ee17552-601a-4e04-beee-f310766e1e49", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "--- Checking Train split ---\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c2991e9b69a8411e8d1aa2a523820e56", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading dataset from disk: 0%| | 0/23 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Index 9107: Failed to load. Error: Error opening <_io.BytesIO object at 0x40022bec48b0>: Format not recognised.\n", "Index 16317: Failed to load. Error: Error opening <_io.BytesIO object at 0x40022bec4c20>: Format not recognised.\n", "Index 18151: Failed to load. Error: Error opening <_io.BytesIO object at 0x40022bec48b0>: Format not recognised.\n", "Index 18156: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e03dd0>: Format not recognised.\n", "Index 18163: Failed to load. Error: Error opening <_io.BytesIO object at 0x40022bec4c20>: Format not recognised.\n", "⚠️ Found 5 issues in Train.\n", "\n", "--- Checking Validation split ---\n", "Index 136: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 146: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3df80>: Format not recognised.\n", "Index 162: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f9c0>: Format not recognised.\n", "Index 184: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 205: Failed to load. 
Error: Error opening <_io.BytesIO object at 0x4002a7e3df80>: Format not recognised.\n", "Index 213: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f9c0>: Format not recognised.\n", "Index 246: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 268: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3df80>: Format not recognised.\n", "Index 281: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f9c0>: Format not recognised.\n", "Index 290: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 321: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f9c0>: Format not recognised.\n", "Index 334: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 337: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 341: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 347: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 624: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 639: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f420>: Format not recognised.\n", "Index 644: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f420>: Format not recognised.\n", "Index 655: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f420>: Format not recognised.\n", "Index 658: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f420>: Format not recognised.\n", "Index 669: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f420>: Format not recognised.\n", "Index 676: Failed to load. 
Error: Error opening <_io.BytesIO object at 0x4002a7e3f9c0>: Format not recognised.\n", "Index 678: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f9c0>: Format not recognised.\n", "Index 681: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f9c0>: Format not recognised.\n", "Index 1348: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f9c0>: Format not recognised.\n", "Index 1378: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 1461: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 1470: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 1491: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 1515: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 1521: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 1538: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 1549: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 1550: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 1573: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "⚠️ Found 35 issues in Validation.\n", "\n", "--- Checking Test split ---\n", "Index 728: Failed to load. Error: Error opening <_io.BytesIO object at 0x40022bda26b0>: Format not recognised.\n", "Index 774: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 793: Failed to load. 
Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 914: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 915: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 947: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3f920>: Format not recognised.\n", "Index 1018: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e19f80>: Format not recognised.\n", "Index 1064: Failed to load. Error: Error opening <_io.BytesIO object at 0x4002a7e3e980>: Format not recognised.\n", "⚠️ Found 8 issues in Test.\n" ] } ], "source": [ "# Run the check on your saved folders\n", "bad_train = check_corrupted_files(\"Train\", \"./kikuyu_dataset/train\")\n", "bad_val = check_corrupted_files(\"Validation\", \"./kikuyu_dataset/validation\")\n", "bad_test = check_corrupted_files(\"Test\", \"./kikuyu_dataset/test\")" ] }, { "cell_type": "code", "execution_count": 26, "id": "43a52dbe-d695-4e41-85a5-ff2ec4520cbc", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "--- Investigating Row 9107 ---\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "43ff3b801a4f486386707d9dd363c1c1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading dataset from disk: 0%| | 0/23 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Size on disk: 319932 bytes\n", "Hex Signature: 55 6b 6c 47 52 6b 53 70 41 77 42 58\n", "ASCII Signature: b'UklGRkSpAwBX'\n", "Conclusion: Header Mismatch. 
The data is not a valid WAV file.\n" ] } ], "source": [ "def diagnose_final_attempt(dataset_path, index):\n", " print(f\"\\n--- Investigating Row {index} ---\")\n", " ds = load_from_disk(dataset_path)\n", " \n", " # Accessing the raw Arrow table to avoid any Decoding or Casting\n", " table = ds.data\n", " raw_row = table.slice(index, 1).to_pydict()\n", " audio_data = raw_row['audio'][0]\n", " \n", " if audio_data is None:\n", " print(\"Result: The entire audio column for this row is NULL.\")\n", " return\n", " raw_bytes = audio_data.get('bytes')\n", " \n", " if raw_bytes is None:\n", " print(\"Result: Dictionary exists but 'bytes' key is missing or empty.\")\n", " print(f\"Available keys in audio dict: {list(audio_data.keys())}\")\n", " else:\n", " byte_len = len(raw_bytes)\n", " print(f\"Size on disk: {byte_len} bytes\")\n", " \n", " # Checking the first 12 bytes for a 'Signature'\n", " signature = raw_bytes[:12]\n", " print(f\"Hex Signature: {signature.hex(' ')}\")\n", " print(f\"ASCII Signature: {signature}\")\n", "\n", " if byte_len < 500:\n", " print(\"Conclusion: File is too small. This is a corrupted stub.\")\n", " elif b'RIFF' not in signature[:4]:\n", " print(\"Conclusion: Header Mismatch. 
The data is not a valid WAV file.\")\n", "\n", "# checking corrupted files in a train index\n", "diagnose_final_attempt(\"./kikuyu_dataset/train\", 9107)" ] }, { "cell_type": "code", "execution_count": 27, "id": "1fc752c9-334f-4f69-8eb1-af5c6a288ef3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--- Cleaning Train ---\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7c5a4e83ee3647c9b2f98b82ae8afda3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading dataset from disk: 0%| | 0/23 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cad73190b6da4b50a844cd5b25176740", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Saving the dataset (0/23 shards): 0%| | 0/21161 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Done: Removed 5 files. New total: 21161\n", "--- Cleaning Validation ---\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2794e073a7874f30912756faf28f1a9d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Saving the dataset (0/4 shards): 0%| | 0/4108 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Done: Removed 35 files. New total: 4108\n", "--- Cleaning Test ---\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "20874814cf904b3e83b9846267a9aead", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Saving the dataset (0/4 shards): 0%| | 0/3748 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Done: Removed 8 files. 
New total: 3748\n" ] } ], "source": [ "def filter_and_replace(name, bad_indices, path):\n", " print(f\"--- Cleaning {name} ---\")\n", " \n", " # 1. Load the current version\n", " ds = load_from_disk(path)\n", " total_before = len(ds)\n", " good_indices = [i for i in range(total_before) if i not in bad_indices]\n", "\n", " clean_ds = ds.select(good_indices)\n", " clean_path = f\"{path}_clean\"\n", " clean_ds.save_to_disk(clean_path)\n", " print(f\"Done: Removed {len(bad_indices)} files. New total: {len(clean_ds)}\")\n", " return clean_path\n", "\n", "# Executing the cleaning phase\n", "clean_train_path = filter_and_replace(\"Train\", bad_train, \"./kikuyu_dataset/train\")\n", "clean_val_path = filter_and_replace(\"Validation\", bad_val, \"./kikuyu_dataset/validation\")\n", "clean_test_path = filter_and_replace(\"Test\", bad_test, \"./kikuyu_dataset/test\")" ] }, { "cell_type": "code", "execution_count": 29, "id": "eb67c489-b378-462d-80ad-75f4ef97470f", "metadata": {}, "outputs": [], "source": [ "# Listing the original corrupted folders\n", "original_folders = [\"./kikuyu_dataset/train\", \"./kikuyu_dataset/validation\", \"./kikuyu_dataset/test\"]\n", "\n", "for folder in original_folders:\n", " if os.path.exists(folder):\n", " shutil.rmtree(folder) \n", " print(f\"Deleted corrupted folder: {folder}\")\n", "\n", "# Renaming data to the original names\n", "os.rename(\"./kikuyu_dataset/train_clean\", \"./kikuyu_dataset/train\")\n", "os.rename(\"./kikuyu_dataset/validation_clean\", \"./kikuyu_dataset/validation\")\n", "os.rename(\"./kikuyu_dataset/test_clean\", \"./kikuyu_dataset/test\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "10dedd72-60ca-42ed-a208-7d6d2ab8ccee", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "526b8ec57ff04047b74e9e42f8d41a87", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading dataset from disk: 0%| | 0/46 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": 
"display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Train: 49591 rows\n", "Val: 6789 rows\n", "Test: 6514 rows\n" ] } ], "source": [ "# Loading datasets from the disk\n", "train_ds = load_from_disk(\"./Combined_ANV_dataset/train\")\n", "val_ds = load_from_disk(\"./Combined_ANV_dataset/validation\")\n", "test_ds = load_from_disk(\"./Combined_ANV_dataset/test\")\n", "\n", "print(f\"Train: {len(train_ds)} rows\") \n", "print(f\"Val: {len(val_ds)} rows\")\n", "print(f\"Test: {len(test_ds)} rows\")" ] }, { "cell_type": "code", "execution_count": 6, "id": "43d13e1f-1a9b-465f-8584-82783e1cc3e7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train Duration: 71.74 hours\n", "Val Duration: 8.39 hours\n" ] } ], "source": [ "def calculate_total_hours(dataset):\n", " # Calculate duration for each sample: length of array / sampling_rate\n", " # This might take a minute for 21k rows\n", " total_seconds = sum(len(x[\"array\"]) / x[\"sampling_rate\"] for x in dataset[\"audio\"])\n", " return total_seconds / 3600\n", "\n", "print(f\"Train Duration: {calculate_total_hours(train_ds):.2f} hours\")\n", "print(f\"Val Duration: {calculate_total_hours(val_ds):.2f} hours\")" ] }, { "cell_type": "code", "execution_count": 32, "id": "b8d5d243-b47d-44ec-86c0-645e87f5c3c5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train Sampling Rate: 44100 Hz\n", "Validation Sampling Rate: 44100 Hz\n", "Test Sampling Rate: 44100 Hz\n" ] } ], "source": [ "#checking the sampling rate for the first audio\n", "def check_sampling_rate(dataset, name):\n", " sr = dataset[0][\"audio\"][\"sampling_rate\"]\n", " print(f\"{name} Sampling Rate: {sr} Hz\")\n", " return sr\n", "\n", "train_sr = check_sampling_rate(train_ds, \"Train\")\n", "val_sr = check_sampling_rate(val_ds, \"Validation\")\n", "test_sr = check_sampling_rate(test_ds, \"Test\")" ] }, { "cell_type": "code", "execution_count": 7, "id": 
"e926475a-853b-4f58-b3fb-7a27c457f76d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "New Sampling Rate: 16000 Hz\n" ] } ], "source": [ "#resampling to the standard sampling rate 16khz\n", "train_ds = train_ds.cast_column(\"audio\", Audio(sampling_rate=16000))\n", "val_ds = val_ds.cast_column(\"audio\", Audio(sampling_rate=16000))\n", "test_ds = test_ds.cast_column(\"audio\", Audio(sampling_rate=16000))\n", "print(f\"New Sampling Rate: {train_ds[0]['audio']['sampling_rate']} Hz\")" ] }, { "cell_type": "code", "execution_count": 8, "id": "52272bdd-03de-4b3d-9398-972039f0fc1d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),\n", " 'filename': Value(dtype='string', id=None),\n", " 'type': Value(dtype='string', id=None),\n", " 'split': Value(dtype='string', id=None),\n", " 'recorder_uuid': Value(dtype='string', id=None),\n", " 'domain': Value(dtype='string', id=None),\n", " 'transcription': Value(dtype='string', id=None),\n", " 'language': Value(dtype='string', id=None)}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#getting the features\n", "train_ds.features" ] }, { "cell_type": "code", "execution_count": 9, "id": "d93c22e2-6fb7-43a9-b003-3699ab62131d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Features: dict_keys(['audio', 'transcription'])\n" ] } ], "source": [ "# dropping unnecessary columns\n", "columns_to_remove = [\"filename\", \"type\", \"split\", \"recorder_uuid\", \"domain\", \"language\"]\n", "\n", "# Applying to each split\n", "train_ds = train_ds.remove_columns(columns_to_remove)\n", "val_ds = val_ds.remove_columns(columns_to_remove)\n", "test_ds = test_ds.remove_columns(columns_to_remove)\n", "print(\" Features:\", train_ds.features.keys())" ] }, { "cell_type": "code", "execution_count": null, "id": "2f994a96-1582-4950-8701-9822d95a5d2e", 
"metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 10, "id": "8bc0befd-0880-4706-b3da-fa0a2331a6e8", "metadata": {}, "outputs": [], "source": [ "#tokenizer = WhisperTokenizer.from_pretrained(\"openai/whisper-tiny\", task=\"transcribe\")" ] }, { "cell_type": "code", "execution_count": 11, "id": "23a2e0ee-5319-4eba-b17b-561a05f887f2", "metadata": {}, "outputs": [], "source": [ "#calling the tokenizer and the feature extractor\n", "processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\", task=\"transcribe\")" ] }, { "cell_type": "code", "execution_count": 12, "id": "e9b0fbad-6071-4089-87cd-bfcc0a2ab6b7", "metadata": {}, "outputs": [], "source": [ "#normalizing data\n", "do_lower_case = False\n", "do_remove_punctuation = False\n", "\n", "normalizer = BasicTextNormalizer()" ] }, { "cell_type": "code", "execution_count": 13, "id": "9a3ef6d1-315e-46b6-b758-09ac0b4125ba", "metadata": {}, "outputs": [], "source": [ "def prepare_dataset(batch):\n", " # load and (possibly) resample audio data to 16kHz\n", " audio = batch[\"audio\"]\n", "\n", " # compute log-Mel input features from input audio array \n", " batch[\"input_features\"] = processor.feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n", " # compute input length of audio sample in seconds\n", " batch[\"input_length\"] = len(audio[\"array\"]) / audio[\"sampling_rate\"]\n", " \n", " # optional pre-processing steps\n", " transcription = batch[\"transcription\"]\n", " if do_lower_case:\n", " transcription = transcription.lower()\n", " if do_remove_punctuation:\n", " transcription = normalizer(transcription).strip()\n", " \n", " # encode target text to label ids\n", " batch[\"labels\"] = processor.tokenizer(transcription).input_ids\n", " return batch\n", "\n", "\n", "# def prepare_dataset(examples):\n", "# # compute log-Mel input features from input audio array \n", "# audio = examples[\"audio\"]\n", "# 
examples[\"input_features\"] = feature_extractor(\n", "# audio[\"array\"], sampling_rate=16000).input_features[0]\n", "# del examples[\"audio\"]\n", "# transcription = examples[\"transcription\"]\n", "\n", "# # encode target text to label ids \n", "# examples[\"labels\"] = tokenizer(transcription).input_ids\n", "# del examples[\"transcription\"]\n", "# return examples" ] }, { "cell_type": "code", "execution_count": 14, "id": "ea359827-a3c3-44f1-aac6-a95da384b291", "metadata": {}, "outputs": [], "source": [ "#mapping the train data\n", "train_ds = train_ds.map(prepare_dataset, num_proc=4)" ] }, { "cell_type": "code", "execution_count": 15, "id": "c408e4e0-5c69-4f95-9b76-74a3fe05e7bf", "metadata": {}, "outputs": [], "source": [ "train_ds.set_format(type=\"torch\", columns=[\"input_features\", \"labels\"])" ] }, { "cell_type": "code", "execution_count": 16, "id": "6e0d1fb9-b9a3-443a-acaa-0170317bb77f", "metadata": {}, "outputs": [], "source": [ "val_ds = val_ds.map(prepare_dataset, num_proc=4)\n", "val_ds.set_format(type=\"torch\", columns=[\"input_features\", \"labels\"])" ] }, { "cell_type": "code", "execution_count": 41, "id": "b665f037-7d5b-4f03-bfd1-24c3156efb5a", "metadata": {}, "outputs": [], "source": [ "max_input_length = 15.0\n", "\n", "def is_audio_in_length_range(length):\n", " return length < max_input_length" ] }, { "cell_type": "code", "execution_count": 44, "id": "05685bcb-d657-441c-988b-ea5b51bdf723", "metadata": {}, "outputs": [], "source": [ "train_ds= train_ds.filter(\n", " is_audio_in_length_range,\n", " input_columns=[\"input_length\"],\n", ")" ] }, { "cell_type": "code", "execution_count": 17, "id": "0cea908a-5e39-4558-bd05-24c67b81b1fb", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "from dataclasses import dataclass\n", "from typing import Any, Dict, List, Union\n", "\n", "@dataclass\n", "class DataCollatorSpeechSeq2SeqWithPadding:\n", " processor: Any\n", "\n", " def __call__(self, features: List[Dict[str, 
Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n", " # splitting inputs and labels since they have to be of different lengths and need different padding methods\n", " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n", " batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n", "\n", " # getting the tokenized label sequences\n", " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n", " # pad the labels to max length\n", " labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n", "\n", " # replacing padding with -100 to ignore loss correctly\n", " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n", "\n", " # if bos token is appended in previous tokenization step,\n", " # cut bos token here as it's appended later anyway\n", " if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n", " labels = labels[:, 1:]\n", "\n", " batch[\"labels\"] = labels\n", "\n", " return batch" ] }, { "cell_type": "code", "execution_count": 18, "id": "c81326df-8e17-45e0-a45a-53d69acf0978", "metadata": {}, "outputs": [], "source": [ "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)" ] }, { "cell_type": "code", "execution_count": 19, "id": "3fc14a71-effa-420f-a02a-c841317a8c26", "metadata": {}, "outputs": [], "source": [ "#evaluation metrics\n", "metric_wer = evaluate.load(\"wer\")\n", "metric_cer = evaluate.load(\"cer\")" ] }, { "cell_type": "code", "execution_count": 20, "id": "397371aa-57df-4fb0-986e-fa7d06af3ff3", "metadata": {}, "outputs": [], "source": [ "# # evaluate with the 'normalised' WER\n", "# do_normalize_eval = True\n", "\n", "# def compute_metrics(pred):\n", "# pred_ids = pred.predictions\n", "# label_ids = pred.label_ids\n", "\n", "# # replace -100 with the pad_token_id\n", "# label_ids[label_ids == -100] = 
processor.tokenizer.pad_token_id\n", "\n", "# # we do not want to group tokens when computing the metrics\n", "# pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n", "# label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n", "\n", "# if do_normalize_eval:\n", "# pred_str = [normalizer(pred) for pred in pred_str]\n", "# label_str = [normalizer(label) for label in label_str]\n", "\n", "# wer = 100 * metric.compute(predictions=pred_str, references=label_str)\n", "\n", "# return {\"wer\": wer}\n", "\n", "do_normalize_eval = True\n", "def compute_metrics(pred):\n", " pred_ids = pred.predictions\n", " label_ids = pred.label_ids\n", "\n", " # Replace -100 with the pad_token_id\n", " label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n", "\n", " # Decode predictions and labels\n", " pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n", " label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n", "\n", " # Normalize if requested\n", " if do_normalize_eval:\n", " pred_str = [normalizer(pred) for pred in pred_str]\n", " label_str = [normalizer(label) for label in label_str]\n", "\n", " # Compute both metrics\n", " wer = 100 * metric_wer.compute(predictions=pred_str, references=label_str)\n", " cer = 100 * metric_cer.compute(predictions=pred_str, references=label_str)\n", "\n", " return {\"wer\": wer, \"cer\": cer}\n" ] }, { "cell_type": "code", "execution_count": 21, "id": "ad9fac1e-66fa-4196-894b-7cfc88daa148", "metadata": {}, "outputs": [], "source": [ "#loading the model\n", "model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")" ] }, { "cell_type": "code", "execution_count": 22, "id": "d7682c32-2532-4fda-b851-8c9358c7a7da", "metadata": {}, "outputs": [], "source": [ "model.config.forced_decoder_ids = None\n", "model.config.suppress_tokens = []\n", "model.config.use_cache = False" ] }, { "cell_type": "code", 
"execution_count": 23, "id": "a76505d7-9bb6-4cad-a745-d7bcabd632b3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/users/mkariuki/miniconda3/envs/lenv/lib/python3.12/site-packages/transformers/training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n", " warnings.warn(\n" ] } ], "source": [ "# defining training arguments\n", "# NOTE: `eval_strategy` replaces the deprecated `evaluation_strategy` kwarg (removed in transformers 4.46)\n", "training_args = Seq2SeqTrainingArguments(\n", " output_dir=\"./\",\n", " per_device_train_batch_size=64,\n", " gradient_accumulation_steps=1,\n", " learning_rate=1e-5,\n", " warmup_steps=500,\n", " max_steps=1500,\n", " gradient_checkpointing=False,\n", " fp16=True,\n", " eval_strategy=\"steps\",\n", " per_device_eval_batch_size=8,\n", " predict_with_generate=True,\n", " generation_max_length=225,\n", " save_steps=1000,\n", " eval_steps=200,\n", " logging_steps=25,\n", " report_to=[\"tensorboard\"],\n", " # track best checkpoint by (lower-is-better) WER\n", " load_best_model_at_end=True,\n", " metric_for_best_model=\"wer\",\n", " dataloader_num_workers=8,\n", " greater_is_better=False,\n", " remove_unused_columns=False,\n", " push_to_hub=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 24, "id": "f5ce7750-d1f1-4cc5-b01e-9e22a43d31c4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_80560/2391572708.py:1: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Seq2SeqTrainer.__init__`. 
"Use `processing_class` instead.\n", " trainer = Seq2SeqTrainer(\n" ] } ], "source": [ "# NOTE: `processing_class` replaces the deprecated `tokenizer` kwarg (removed in transformers 5.0)\n", "trainer = Seq2SeqTrainer(\n", " args=training_args,\n", " model=model,\n", " train_dataset=train_ds,\n", " eval_dataset=val_ds,\n", " data_collator=data_collator,\n", " compute_metrics=compute_metrics,\n", " processing_class=processor,\n", ")" ] }, { "cell_type": "code", "execution_count": 25, "id": "8590d11d-0c79-4298-891e-8e0958dc6887", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "processor.save_pretrained(training_args.output_dir)" ] }, { "cell_type": "code", "execution_count": 26, "id": "7baf26c9-3b97-4995-a1f4-15a3d3120dd8", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/users/mkariuki/miniconda3/envs/lenv/lib/python3.12/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", " warnings.warn(\n" ] }, { "data": { "text/html": [ "\n", "
| Step | \n", "Training Loss | \n", "Validation Loss | \n", "Wer | \n", "Cer | \n", "
|---|---|---|---|---|
| 200 | \n", "1.173500 | \n", "1.062414 | \n", "90.777114 | \n", "35.807545 | \n", "
| 400 | \n", "0.715700 | \n", "0.696632 | \n", "62.687569 | \n", "23.310166 | \n", "
| 600 | \n", "0.569400 | \n", "0.592113 | \n", "51.939734 | \n", "18.747940 | \n", "
| 800 | \n", "0.500300 | \n", "0.549826 | \n", "48.891363 | \n", "18.606143 | \n", "
| 1000 | \n", "0.460100 | \n", "0.527498 | \n", "48.217336 | \n", "18.207214 | \n", "
| 1200 | \n", "0.449500 | \n", "0.516708 | \n", "49.516591 | \n", "19.761990 | \n", "
| 1400 | \n", "0.440700 | \n", "0.510588 | \n", "48.494876 | \n", "19.522083 | \n", "
"
],
"text/plain": [
"