diff --git "a/src/.ipynb_checkpoints/Copy_of_MusicGen-checkpoint.ipynb" "b/src/.ipynb_checkpoints/Copy_of_MusicGen-checkpoint.ipynb"
new file mode 100644
--- /dev/null
+++ "b/src/.ipynb_checkpoints/Copy_of_MusicGen-checkpoint.ipynb"
@@ -0,0 +1,2633 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "70300319-d206-43ce-b3bf-3da6b079f20f",
+ "metadata": {
+ "id": "70300319-d206-43ce-b3bf-3da6b079f20f"
+ },
+ "source": [
+ "## MusicGen in 🤗 Transformers\n",
+ "\n",
+ "**by [Sanchit Gandhi](https://huggingface.co/sanchit-gandhi)**\n",
+ "\n",
+ "MusicGen is a Transformer-based model capable of generating high-quality music samples conditioned on text descriptions or audio prompts. It was proposed in the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet et al. from Meta AI.\n",
+ "\n",
+ "The MusicGen model can be decomposed into three distinct stages:\n",
+ "1. The text descriptions are passed through a frozen text encoder model to obtain a sequence of hidden-state representations\n",
+ "2. The MusicGen decoder is then trained to predict discrete audio tokens, or *audio codes*, conditioned on these hidden-states\n",
+ "3. These audio tokens are then decoded using an audio compression model, such as EnCodec, to recover the audio waveform\n",
+ "\n",
+ "The pre-trained MusicGen checkpoints use Google's [t5-base](https://huggingface.co/t5-base) as the text encoder model, and [EnCodec 32kHz](https://huggingface.co/facebook/encodec_32khz) as the audio compression model. The MusicGen decoder is a pure language model architecture,\n",
+ "trained from scratch on the task of music generation.\n",
+ "\n",
+ "The novelty in the MusicGen model is how the audio codes are predicted. Traditionally, each codebook has to be predicted by a separate model (i.e. hierarchically) or by continuously refining the output of the Transformer model (i.e. upsampling). MusicGen uses an efficient *token interleaving pattern*, thus eliminating the need to cascade multiple models to predict a set of codebooks. Instead, it is able to generate the full set of codebooks in a single forward pass of the decoder, resulting in much faster inference.\n",
+ "\n",
+ "\n",
+ "\n",
+ "**Figure 1:** Codebook delay pattern used by MusicGen. Figure taken from the [MusicGen paper](https://arxiv.org/abs/2306.05284).\n"
+ ]
+ },
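+ {
+ "cell_type": "markdown",
+ "id": "delay-pattern-sketch-md",
+ "metadata": {},
+ "source": [
+ "As a toy sketch (an illustration of the delay-pattern idea only, not the actual `transformers` implementation), shifting codebook `k` right by `k` steps lets the decoder predict one token per codebook at every step:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "delay-pattern-sketch-code",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "# toy delay interleaving pattern: codebook k is shifted right by k steps,\n",
+ "# with -1 marking the padded positions\n",
+ "num_codebooks, seq_len, pad = 4, 6, -1\n",
+ "codes = torch.arange(seq_len).repeat(num_codebooks, 1)  # dummy audio codes\n",
+ "\n",
+ "delayed = torch.full((num_codebooks, seq_len + num_codebooks - 1), pad)\n",
+ "for k in range(num_codebooks):\n",
+ "    delayed[k, k : k + seq_len] = codes[k]\n",
+ "\n",
+ "delayed"
+ ]
+ },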
+ {
+ "cell_type": "markdown",
+ "id": "e70e6dbb-3211-4ef9-93f6-efaba764ac77",
+ "metadata": {
+ "id": "e70e6dbb-3211-4ef9-93f6-efaba764ac77"
+ },
+ "source": [
+ "## Prepare the Environment"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "04d1fb09-4e19-4e82-a4fa-eea7b20bb96c",
+ "metadata": {
+ "id": "04d1fb09-4e19-4e82-a4fa-eea7b20bb96c"
+ },
+ "source": [
+ "Let's make sure we're connected to a GPU to run this notebook. To get a GPU, click `Runtime` -> `Change runtime type`, then change `Hardware accelerator` from `None` to `GPU`. We can verify that we've been assigned a GPU and view its specifications through the `nvidia-smi` command:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "21d38c22-bb79-495c-8aa9-09ceabb2957a",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "21d38c22-bb79-495c-8aa9-09ceabb2957a",
+ "outputId": "bcb3fd96-cc9c-45fc-a0ce-2fc79398cedc"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Thu Sep 7 00:23:14 2023 \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| NVIDIA-SMI 525.105.17 Driver Version: 525.105.17 CUDA Version: 12.0 |\n",
+ "|-------------------------------+----------------------+----------------------+\n",
+ "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
+ "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
+ "| | | MIG M. |\n",
+ "|===============================+======================+======================|\n",
+ "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
+ "| N/A 41C P8 9W / 70W | 0MiB / 15360MiB | 0% Default |\n",
+ "| | | N/A |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ " \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| Processes: |\n",
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
+ "| ID ID Usage |\n",
+ "|=============================================================================|\n",
+ "| No running processes found |\n",
+ "+-----------------------------------------------------------------------------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1abcac4f-06b0-41c7-b7e4-960ddd297afd",
+ "metadata": {
+ "id": "1abcac4f-06b0-41c7-b7e4-960ddd297afd"
+ },
+ "source": [
+ "We see here that we've got a Tesla T4 16GB GPU, although this may vary for you depending on GPU availability and Colab GPU assignment.\n",
+ "\n",
+ "Next, we install the 🤗 Transformers package from the main branch, as well as the 🤗 Datasets package to load audio files for audio-prompted generation:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "66af0411-c18e-4d8b-b6d9-318ff4460e48",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "66af0411-c18e-4d8b-b6d9-318ff4460e48",
+ "outputId": "6c0be17d-ac83-49f4-eff5-4459474c0c97"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m14.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m26.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m13.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m29.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m45.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m519.6/519.6 kB\u001b[0m \u001b[31m41.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m16.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m24.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Building wheel for transformers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install --upgrade --quiet pip\n",
+ "!pip install --quiet git+https://github.com/huggingface/transformers.git datasets[audio]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "77ee39cc-654b-4f0e-b601-013e484c16f0",
+ "metadata": {
+ "id": "77ee39cc-654b-4f0e-b601-013e484c16f0"
+ },
+ "source": [
+ "## Load the Model\n",
+ "\n",
+ "The pre-trained MusicGen small, medium and large checkpoints can be loaded from the [pre-trained weights](https://huggingface.co/models?search=facebook/musicgen-) on the Hugging Face Hub. Change the repo id to match the checkpoint size you wish to load. We'll default to the small checkpoint, which is the fastest of the three but has the lowest audio quality:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "b0d87424-9f38-4658-ba47-2a465d52ad77",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 113,
+ "referenced_widgets": [
+ "013191f7a16c49fdbfc9b6cb3b0aa089",
+ "1c58ed64d7144bcf9c58fe0f89364d61",
+ "ce32b8aaae2f4cd38d5ec78fefaa34ce",
+ "7ed4f572d3534679a3e1e4d90880bd71",
+ "fd741db4538f493588106d753a747593",
+ "0aa3f4c09c854d90b9357d356e8ed46b",
+ "bbf7a9706dd64f6eb24d4b46ff52bc23",
+ "7f52d5efbb064170ac8d8681ae92f29b",
+ "881fd640878c420cbc14dfd5f8516953",
+ "f74afcde56c64ab389fed5bf7f5964b8",
+ "09bbff09eaa44cdf92612c7f02f05f63",
+ "0cd8b835ee724b659c92fbee8eb62327",
+ "50323b27b7ff4823b733bffa97930218",
+ "abcf0eeed63f4331bbc9a044b2d3d65f",
+ "2fa893516c3946729fa0eda21550f658",
+ "7742549b504c4e71b4d0a2119d459936",
+ "6081dfc2472b4097b01bc7608c100ca3",
+ "a810f25c280d48539d7382064588efd7",
+ "b93f98b1284b4cf8bc325e9c4d9a2bf1",
+ "a43852c0c4754235a7cf3fb7221eed1d",
+ "410cc9979cc44f5aad48d17cf6042165",
+ "768bf7c8a965476c99693c9aa3ea89cb",
+ "f938b2a799d5454f8885d5d6bba24b94",
+ "a2a0abc232c04d8c96704d6b012227aa",
+ "4d191abe24514b6a8a5bcb7aa4dd624c",
+ "9350625ba0fe43949ea335a27f8e402d",
+ "e9895aee370a4d888cefa7a82cd90c00",
+ "dbcc2550fed34fc4b9d1f9b5e7465b85",
+ "aca7c88abc684793880a4619257522c4",
+ "f83972e889134265aff6737844971e6f",
+ "9fa6a2c3f81e4a159cf652d8a48acd37",
+ "92804f5fa9ef49a094ce3d54b051ff9c",
+ "9ac4c059958446b1af03da2fe2e0ac20"
+ ]
+ },
+ "id": "b0d87424-9f38-4658-ba47-2a465d52ad77",
+ "outputId": "2af6a8e9-ff4c-4f2d-8aa2-89f35f0575e7"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "013191f7a16c49fdbfc9b6cb3b0aa089",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (β¦)lve/main/config.json: 0%| | 0.00/7.87k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0cd8b835ee724b659c92fbee8eb62327",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading model.safetensors: 0%| | 0.00/2.36G [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f938b2a799d5454f8885d5d6bba24b94",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading (β¦)neration_config.json: 0%| | 0.00/224 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from transformers import MusicgenForConditionalGeneration\n",
+ "\n",
+ "model = MusicgenForConditionalGeneration.from_pretrained(\"facebook/musicgen-small\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4981d112-407c-4120-86aa-5c6a170543f7",
+ "metadata": {
+ "id": "4981d112-407c-4120-86aa-5c6a170543f7"
+ },
+ "source": [
+ "We can then place the model on our accelerator device (if available), or leave it on the CPU otherwise:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "9508dee8-39df-46fe-82f3-6cc2e9f21a97",
+ "metadata": {
+ "id": "9508dee8-39df-46fe-82f3-6cc2e9f21a97"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
+ "model.to(device);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f6e1166e-1335-4555-9ec4-223d1fbcb547",
+ "metadata": {
+ "id": "f6e1166e-1335-4555-9ec4-223d1fbcb547"
+ },
+ "source": [
+ "## Generation\n",
+ "\n",
+ "MusicGen is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly\n",
+ "better results than greedy, so we encourage using sampling mode where possible. Sampling is enabled by default,\n",
+ "and can be explicitly specified by setting `do_sample=True` in the call to `MusicgenForConditionalGeneration.generate` (see below).\n",
+ "\n",
+ "### Unconditional Generation\n",
+ "\n",
+ "The inputs for unconditional (or 'null') generation can be obtained through the method `MusicgenForConditionalGeneration.get_unconditional_inputs`. We can then run auto-regressive generation using the `.generate` method, specifying `do_sample=True` to enable sampling mode:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "fb7708e8-e4f1-4ab8-b04a-19395d78dea2",
+ "metadata": {
+ "id": "fb7708e8-e4f1-4ab8-b04a-19395d78dea2"
+ },
+ "outputs": [],
+ "source": [
+ "unconditional_inputs = model.get_unconditional_inputs(num_samples=1)\n",
+ "\n",
+ "audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "94cb74df-c194-4d2e-930a-12473b08a919",
+ "metadata": {
+ "id": "94cb74df-c194-4d2e-930a-12473b08a919"
+ },
+ "source": [
+ "The audio outputs are a three-dimensional Torch tensor of shape `(batch_size, num_channels, sequence_length)`. To listen\n",
+ "to the generated audio samples, you can either play them in a Jupyter notebook:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "15f0bc7c-b899-4e7a-943e-594e73f080ea",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 75
+ },
+ "id": "15f0bc7c-b899-4e7a-943e-594e73f080ea",
+ "outputId": "9a47adc3-17b1-40d6-989d-73319c9ea7ee"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " Your browser does not support the audio element.\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from IPython.display import Audio\n",
+ "\n",
+ "sampling_rate = model.config.audio_encoder.sampling_rate\n",
+ "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6de58334-40f7-4924-addb-2d6ff34c0590",
+ "metadata": {
+ "id": "6de58334-40f7-4924-addb-2d6ff34c0590"
+ },
+ "source": [
+ "Or save them as a `.wav` file using a third-party library, e.g. `scipy` (note here that we also need to remove the channel dimension from our audio tensor):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "04291f52-0a75-4ddb-9eff-e853d0f17288",
+ "metadata": {
+ "id": "04291f52-0a75-4ddb-9eff-e853d0f17288"
+ },
+ "outputs": [],
+ "source": [
+ "import scipy\n",
+ "\n",
+ "scipy.io.wavfile.write(\"musicgen_out.wav\", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e52ff5b2-c170-4079-93a4-a02acbdaeb39",
+ "metadata": {
+ "id": "e52ff5b2-c170-4079-93a4-a02acbdaeb39"
+ },
+ "source": [
+ "The argument `max_new_tokens` specifies the number of new tokens to generate. As a rule of thumb, you can work out the length of the generated audio sample in seconds by using the frame rate of the EnCodec model:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "d75ad107-e19b-47f3-9cf1-5102ab4ae74a",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "d75ad107-e19b-47f3-9cf1-5102ab4ae74a",
+ "outputId": "c021297b-837b-4d35-e01b-34a66a33c1dd"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "5.12"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "audio_length_in_s = 256 / model.config.audio_encoder.frame_rate\n",
+ "\n",
+ "audio_length_in_s"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9a0e999b-2595-4090-8e1a-acfaa42d2581",
+ "metadata": {
+ "id": "9a0e999b-2595-4090-8e1a-acfaa42d2581"
+ },
+ "source": [
+ "### Text-Conditional Generation\n",
+ "\n",
+ "The model can generate an audio sample conditioned on a text prompt by using the `MusicgenProcessor` to pre-process\n",
+ "the inputs. The pre-processed inputs can then be passed to the `.generate` method to generate text-conditional audio samples.\n",
+ "Again, we enable sampling mode by setting `do_sample=True`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "5fba4154-13f6-403a-958b-101d6eacfb6e",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 75
+ },
+ "id": "5fba4154-13f6-403a-958b-101d6eacfb6e",
+ "outputId": "a4b87d4b-1db0-49af-dac0-eb600c47cbfb"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " Your browser does not support the audio element.\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from transformers import AutoProcessor\n",
+ "\n",
+ "processor = AutoProcessor.from_pretrained(\"facebook/musicgen-small\")\n",
+ "\n",
+ "inputs = processor(\n",
+ " text=[\"michael jackson singing pop\", \"90s rock song with loud guitars and heavy drums unicorn zombie apocalypse yep\"],\n",
+ " padding=True,\n",
+ " return_tensors=\"pt\",\n",
+ ")\n",
+ "\n",
+ "audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=9, max_new_tokens=256)\n",
+ "\n",
+ "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "G4fcYQljBR9k",
+ "metadata": {
+ "id": "G4fcYQljBR9k"
+ },
+ "outputs": [],
+ "source": [
+ "tokens_train = model.audio_encoder(audio_values)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "_FZb_Zo-Dajl",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "_FZb_Zo-Dajl",
+ "outputId": "9834a745-ac10-4142-e22f-6998997304e5"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Collecting einops\n",
+ " Downloading einops-0.6.1-py3-none-any.whl (42 kB)\n",
+ "\u001b[?25l \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m0.0/42.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m42.2/42.2 kB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hInstalling collected packages: einops\n",
+ "Successfully installed einops-0.6.1\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install einops\n",
+ "from einops import rearrange"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "b9SOV4EWClzz",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "b9SOV4EWClzz",
+ "outputId": "94e23569-daba-4716-f3a0-0e84c389715c"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 2, 4, 253])"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokens_train.audio_codes.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c93af9f1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokens_on_format = rearrange(tokens_train.audio_codes, \"n b c l -> (n b c) l\")\n",
+ "print(tokens_on_format.shape)\n",
+ "print(tokens_train.audio_codes[0][0][0][0:10])\n",
+ "print(tokens_on_format[0][0:10])\n",
+ "tokens_on_format_test = tokens_on_format[:, 0:252]\n",
+ "print(tokens_on_format_test.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "VHyalHz78TlY",
+ "metadata": {
+ "id": "VHyalHz78TlY"
+ },
+ "outputs": [],
+ "source": [
+ "encoder_hidden_states = model.text_encoder(inputs.input_ids, attention_mask=inputs.attention_mask).last_hidden_state"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "yi3cupmy95pp",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "yi3cupmy95pp",
+ "outputId": "55e24db1-703e-4095-e969-f8b232a43ef8"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 27, 768])"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "encoder_hidden_states.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "J4iKmH5P9fkW",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "J4iKmH5P9fkW",
+ "outputId": "17e1ef8a-8065-48d0-b724-13de90d818fd"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1024"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.decoder.config.hidden_size"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "V0jitEJh9otm",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "V0jitEJh9otm",
+ "outputId": "c3285678-2d9f-4c1f-ad72-768d9b617878"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "768"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.text_encoder.config.hidden_size"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "dOnCaS3F9yfz",
+ "metadata": {
+ "id": "dOnCaS3F9yfz"
+ },
+ "outputs": [],
+ "source": [
+ "encoder_hidden_states = model.enc_to_dec_proj(encoder_hidden_states)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "GgSnW4Oy99dr",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "GgSnW4Oy99dr",
+ "outputId": "5af8f2e2-fef5-4ba0-e368-5085ee29aa4b"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 27, 1024])"
+ ]
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "encoder_hidden_states.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "z56a5Yjqoum5",
+ "metadata": {
+ "id": "z56a5Yjqoum5"
+ },
+ "outputs": [],
+ "source": [
+ "pad_token_id = model.generation_config.pad_token_id\n",
+ "decoder_input_ids = (\n",
+ " torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long)\n",
+ " * pad_token_id\n",
+ ").to('cuda')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "xLrDtoYhqBiP",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "xLrDtoYhqBiP",
+ "outputId": "78909eb6-af29-4655-d473-86359093c9d7"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([8, 1])"
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "decoder_input_ids.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "509-WwFFpqkb",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "509-WwFFpqkb",
+ "outputId": "57e716ec-7b25-456d-87da-fe1c4869b0a4"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 27, 1024])"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "l = encoder_hidden_states * inputs.attention_mask[..., None]\n",
+ "l.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "893e0411",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# build a causal (lower-triangular) attention mask for the (8, 252) decoder input\n",
+ "decoder_attention_mask = torch.tril(torch.ones((252, 252), dtype=torch.long)).to('cuda')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "hlKZ7dJgH1VJ",
+ "metadata": {
+ "id": "hlKZ7dJgH1VJ"
+ },
+ "outputs": [],
+ "source": [
+ "results_manual = model.decoder(encoder_hidden_states=encoder_hidden_states * inputs.attention_mask[..., None], input_ids=tokens_on_format_test, encoder_attention_mask=inputs.attention_mask, decoder_attention_mask=decoder_attention_mask)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a5fa2fd1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def training_step(labels, results_manual):\n",
+ " # the labels equal the inputs shifted by one: drop the first token position\n",
+ " labels = labels[:, 1:]\n",
+ " # drop the last logit so that logits[:, t] predicts labels[:, t]\n",
+ " results_manual.logits = results_manual.logits[:, :-1, :]\n",
+ " loss_fct = torch.nn.CrossEntropyLoss()\n",
+ " print(rearrange(results_manual.logits, \"c t v -> (c t) v\").shape, rearrange(labels, \"c t -> (c t)\").shape)\n",
+ " loss = loss_fct(rearrange(results_manual.logits, \"c t v -> (c t) v\"), rearrange(labels, \"c t -> (c t)\"))\n",
+ " return loss\n",
+ "\n",
+ "loss = training_step(tokens_on_format_test, results_manual)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "82367433",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# move everything to the CPU\n",
+ "encoder_hidden_states = encoder_hidden_states.cpu()\n",
+ "inputs.attention_mask = inputs.attention_mask.cpu()\n",
+ "decoder_attention_mask = decoder_attention_mask.cpu()\n",
+ "tokens_on_format_test = tokens_on_format_test.cpu()\n",
+ "results_manual.logits = results_manual.logits.cpu()\n",
+ "model = model.cpu()\n",
+ "\n",
+ "\n",
+ "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
+ "for i in range(100):\n",
+ " results_manual = model.decoder(encoder_hidden_states=encoder_hidden_states * inputs.attention_mask[..., None], input_ids=tokens_on_format_test, encoder_attention_mask=inputs.attention_mask)\n",
+ " loss = training_step(tokens_on_format_test, results_manual)\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " optimizer.zero_grad()\n",
+ " print(loss)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "UV1XbuCRONUA",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "UV1XbuCRONUA",
+ "outputId": "0b846e0f-180d-4bd6-946f-fb29665479d1"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([8, 1, 2048])\n",
+ "tensor([-0.1700, -3.6208, -0.9766, -1.1846, -1.3526, 1.4435, 2.6102, -2.6462,\n",
+ " -1.3472, -1.6042], device='cuda:0', grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(results_manual.logits.shape)\n",
+ "#results_manual.past_key_values.shape\n",
+ "print(results_manual.logits[1][0][0:10])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "YpFaaCefZ4mp",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "YpFaaCefZ4mp",
+ "outputId": "c49266f8-4602-47ae-94c9-de472661ba5b"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2048\n"
+ ]
+ }
+ ],
+ "source": [
+ "inputs = processor(\n",
+ " text=[\"michael jackson singing pop\", \"90s rock song with loud guitars and heavy drums\"],\n",
+ " padding=True,\n",
+ " return_tensors=\"pt\",\n",
+ ").to('cuda')\n",
+ "\n",
+ "pad_token_id = model.generation_config.pad_token_id\n",
+ "print(pad_token_id)\n",
+ "decoder_input_ids = (\n",
+ " torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long)\n",
+ " * pad_token_id\n",
+ ").to('cuda')\n",
+ "\n",
+ "end = model(**inputs, decoder_input_ids=decoder_input_ids)\n",
+ "#logits.shape # (bsz * num_codebooks, tgt_len, vocab_size)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "Fdr_sVdLQF3m",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Fdr_sVdLQF3m",
+ "outputId": "d4f55f93-e76f-4a90-9253-b6fb51a9add5"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'input_ids': tensor([[2278, 9, 15, 40, 3, 9325, 739, 8097, 2783, 1, 0, 0,\n",
+ " 0],\n",
+ " [2777, 7, 2480, 2324, 28, 8002, 5507, 7, 11, 2437, 5253, 7,\n",
+ " 1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],\n",
+ " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}"
+ ]
+ },
+ "execution_count": 76,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "inputs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "FyTh1KykP1OT",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "FyTh1KykP1OT",
+ "outputId": "036edf58-44b1-40b8-cbad-137ddad5be9b"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([8, 1, 2048])\n",
+ "tensor([-0.1700, -3.6208, -0.9766, -1.1846, -1.3526, 1.4435, 2.6102, -2.6462,\n",
+ " -1.3472, -1.6042], device='cuda:0', grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(end.logits.shape)\n",
+ "print(end.logits[1][0][0:10])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "h_0zNWeSIXxC",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "h_0zNWeSIXxC",
+ "outputId": "a7a3fe57-ceca-4e35-b8a6-7c870fcfdb54"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[2048],\n",
+ " [2048],\n",
+ " [2048],\n",
+ " [2048],\n",
+ " [2048],\n",
+ " [2048],\n",
+ " [2048],\n",
+ " [2048]], device='cuda:0')"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "print(end.logits[1][1][1:10])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "-1y5EZojYzDz",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "-1y5EZojYzDz",
+ "outputId": "5db3d81c-0a18-4386-92c7-b20354e2c902"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "odict_keys(['logits', 'past_key_values', 'encoder_last_hidden_state'])"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "end.keys()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "tL-0dkVkbeka",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "tL-0dkVkbeka",
+ "outputId": "2838c153-9344-4ebc-96fb-98331e351e72"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "24"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(end.past_key_values)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2X62iV8iVNRU",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 75
+ },
+ "id": "2X62iV8iVNRU",
+ "outputId": "6b97e3ff-f8b7-40a1-9139-d473354b115c"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " Your browser does not support the audio element.\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2xlbMjFTTUBd",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "2xlbMjFTTUBd",
+ "outputId": "c0228562-0a8e-4198-9171-027a5d49e8fc"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 2775, 7, 2783, 1463, 28, 7981, 63, 5253, 7, 11,\n",
+ " 13353, 1, 0],\n",
+ " [ 2777, 7, 2480, 2324, 28, 8002, 5507, 7, 11, 2437,\n",
+ " 5253, 7, 1]], device='cuda:0')"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "inputs.input_ids"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4851a94c-ae02-41c9-b1dd-c1422ba34dc0",
+ "metadata": {
+ "id": "4851a94c-ae02-41c9-b1dd-c1422ba34dc0"
+ },
+ "source": [
+ "The `guidance_scale` is used in classifier-free guidance (CFG), setting the weighting between the conditional logits\n",
+ "(which are predicted from the text prompts) and the unconditional logits (which are predicted from an unconditional or\n",
+ "'null' prompt). A higher guidance scale encourages the model to generate samples that are more closely linked to the input\n",
+ "prompt, usually at the expense of poorer audio quality. CFG is enabled by setting `guidance_scale > 1`. For best results,\n",
+ "use `guidance_scale=3` (the default) for both text- and audio-conditional generation."
+ ]
+ },
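+ {
+ "cell_type": "markdown",
+ "id": "cfg-scale-comparison-note",
+ "metadata": {
+ "id": "cfg-scale-comparison-note"
+ },
+ "source": [
+ "As a quick sketch (re-using the `model`, `inputs` and `device` defined above), you can hear the effect of CFG by comparing generation with `guidance_scale=1` (CFG disabled) against a higher scale:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cfg-scale-comparison",
+ "metadata": {
+ "id": "cfg-scale-comparison"
+ },
+ "outputs": [],
+ "source": [
+ "# guidance_scale=1 disables CFG; larger values follow the text prompt more closely\n",
+ "audio_no_cfg = model.generate(**inputs.to(device), do_sample=True, guidance_scale=1, max_new_tokens=256)\n",
+ "audio_high_cfg = model.generate(**inputs.to(device), do_sample=True, guidance_scale=5, max_new_tokens=256)"
+ ]
+ },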
+ {
+ "cell_type": "markdown",
+ "id": "d391b2a1-6376-4b69-b562-4388b731cf60",
+ "metadata": {
+ "id": "d391b2a1-6376-4b69-b562-4388b731cf60"
+ },
+ "source": [
+ "### Audio-Prompted Generation\n",
+ "\n",
+ "The same `MusicgenProcessor` can be used to pre-process an audio prompt that is used for audio continuation. In the\n",
+ "following example, we load an audio file using the 🤗 Datasets library, pre-process it using the processor class,\n",
+ "and then forward the inputs to the model for generation:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "56a5c28a-f6c1-4ac8-ae08-6776a2b2c5b8",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 75
+ },
+ "id": "56a5c28a-f6c1-4ac8-ae08-6776a2b2c5b8",
+ "outputId": "81c95bfd-649d-424f-a764-0f1b881f37dd"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " Your browser does not support the audio element.\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "dataset = load_dataset(\"sanchit-gandhi/gtzan\", split=\"train\", streaming=True)\n",
+ "sample = next(iter(dataset))[\"audio\"]\n",
+ "\n",
+ "# take the first half of the audio sample\n",
+ "sample[\"array\"] = sample[\"array\"][: len(sample[\"array\"]) // 2]\n",
+ "\n",
+ "inputs = processor(\n",
+ " audio=sample[\"array\"],\n",
+ " sampling_rate=sample[\"sampling_rate\"],\n",
+ " text=[\"80s blues track with groovy saxophone\"],\n",
+ " padding=True,\n",
+ " return_tensors=\"pt\",\n",
+ ")\n",
+ "\n",
+ "audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256)\n",
+ "\n",
+ "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "77518aa4-1b9b-4af6-b5ac-8ecdcb79b4cc",
+ "metadata": {
+ "id": "77518aa4-1b9b-4af6-b5ac-8ecdcb79b4cc"
+ },
+ "source": [
+ "To demonstrate batched audio-prompted generation, we'll slice our sample audio at two different proportions to give two audio samples of different lengths.\n",
+ "Since the input audio prompts vary in length, they will be *padded* to the length of the longest audio sample in the batch before being passed to the model.\n",
+ "\n",
+ "To recover the final audio samples, the generated `audio_values` can be post-processed to remove padding by using the processor class once again:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5495f568-51ca-439d-b47b-8b52e89b78f1",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 75
+ },
+ "id": "5495f568-51ca-439d-b47b-8b52e89b78f1",
+ "outputId": "68866570-000b-4239-bed5-75bc21e828aa"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " Your browser does not support the audio element.\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sample = next(iter(dataset))[\"audio\"]\n",
+ "\n",
+ "# take the first quarter of the audio sample\n",
+ "sample_1 = sample[\"array\"][: len(sample[\"array\"]) // 4]\n",
+ "\n",
+ "# take the first half of the audio sample\n",
+ "sample_2 = sample[\"array\"][: len(sample[\"array\"]) // 2]\n",
+ "\n",
+ "inputs = processor(\n",
+ " audio=[sample_1, sample_2],\n",
+ " sampling_rate=sample[\"sampling_rate\"],\n",
+ " text=[\"80s blues track with groovy saxophone\", \"90s rock song with loud guitars and heavy drums\"],\n",
+ " padding=True,\n",
+ " return_tensors=\"pt\",\n",
+ ")\n",
+ "\n",
+ "audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256)\n",
+ "\n",
+ "# post-process to remove padding from the batched audio\n",
+ "audio_values = processor.batch_decode(audio_values, padding_mask=inputs.padding_mask)\n",
+ "\n",
+ "Audio(audio_values[0], rate=sampling_rate)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "viwTDmzl8ZDN",
+ "metadata": {
+ "id": "viwTDmzl8ZDN"
+ },
+ "source": [
+ "## Generation Config\n",
+ "\n",
+ "The default parameters that control the generation process, such as sampling, guidance scale, and the number of generated tokens, can be found in the model's generation config and updated as desired. Let's first inspect the default generation config:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0zM4notb8Y1g",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "0zM4notb8Y1g",
+ "outputId": "576755e5-42b9-48bc-be13-76322837fea0"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "GenerationConfig {\n",
+ " \"_from_model_config\": true,\n",
+ " \"bos_token_id\": 2048,\n",
+ " \"decoder_start_token_id\": 2048,\n",
+ " \"do_sample\": true,\n",
+ " \"guidance_scale\": 3.0,\n",
+ " \"max_length\": 1500,\n",
+ " \"pad_token_id\": 2048,\n",
+ " \"transformers_version\": \"4.34.0.dev0\"\n",
+ "}"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.generation_config"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "DLSnSwau8jyW",
+ "metadata": {
+ "id": "DLSnSwau8jyW"
+ },
+ "source": [
+ "Alright! We see that the model defaults to using sampling mode (`do_sample=True`), a guidance scale of 3, and a maximum generation length of 1500 tokens (equivalent to 30 seconds of audio). You can update any of these attributes to change the default generation parameters:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ensSj1IB81dA",
+ "metadata": {
+ "id": "ensSj1IB81dA"
+ },
+ "outputs": [],
+ "source": [
+ "# increase the guidance scale to 4.0\n",
+ "model.generation_config.guidance_scale = 4.0\n",
+ "\n",
+ "# set the max new tokens to 256\n",
+ "model.generation_config.max_new_tokens = 256\n",
+ "\n",
+ "# set the softmax sampling temperature to 1.5\n",
+ "model.generation_config.temperature = 1.5"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "UjqGnfc-9ZFJ",
+ "metadata": {
+ "id": "UjqGnfc-9ZFJ"
+ },
+ "source": [
+ "Re-running generation now will use the newly defined values in the generation config:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "KAExrhDl9YvS",
+ "metadata": {
+ "id": "KAExrhDl9YvS"
+ },
+ "outputs": [],
+ "source": [
+ "audio_values = model.generate(**inputs.to(device))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "HdGdoGAs84hS",
+ "metadata": {
+ "id": "HdGdoGAs84hS"
+ },
+ "source": [
+ "Note that any arguments passed to the generate method will **supersede** those in the generation config, so setting `do_sample=False` in the call to generate will override the value of `model.generation_config.do_sample`."
+ ]
+ },
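+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "generate-kwarg-override",
+ "metadata": {
+ "id": "generate-kwarg-override"
+ },
+ "outputs": [],
+ "source": [
+ "# sketch: keyword arguments to generate take precedence over the generation config,\n",
+ "# so this call uses greedy decoding even though model.generation_config.do_sample is True\n",
+ "audio_values_greedy = model.generate(**inputs.to(device), do_sample=False)"
+ ]
+ },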
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "s__neSDH89q0",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "s__neSDH89q0",
+ "outputId": "8ec39d3c-ac4e-4cce-ee89-1e936a453859"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 1, 642560])"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "audio_values.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "EOMbP5f-imWD",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 35
+ },
+ "id": "EOMbP5f-imWD",
+ "outputId": "df158b75-bcc9-4304-db5c-ddb0d1490c49"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ },
+ "text/plain": [
+ "'import os\\nimport torchaudio\\nimport numpy as np\\nimport torch\\nfrom tqdm import tqdm'"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\"\"\"import os\n",
+ "import torchaudio\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from tqdm import tqdm\"\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "knVkOKRxhMAJ",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 192
+ },
+ "id": "knVkOKRxhMAJ",
+ "outputId": "6b6d0063-70bf-4bd8-822a-5c2f2b6363dc"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ },
+ "text/plain": [
+ "'from datasets import load_dataset, Audio\\nfrom transformers import EncodecModel, AutoProcessor\\n\\n\\n# load a demonstration datasets\\nlibrispeech_dummy = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\\n\\n# load the model + processor (for pre-processing the audio)\\nmodel = EncodecModel.from_pretrained(\"facebook/encodec_48khz\")\\nprocessor = AutoProcessor.from_pretrained(\"facebook/encodec_48khz\")\\n\\npath = \\'/content/Stim_Test_Run01_01_rock.wav\\'\\naudio_loaded, sr = torchaudio.load(path)\\nprint(audio_loaded.shape)\\naudio_loaded = torchaudio.transforms.Resample(sr, 24000)(audio_loaded)\\n#audio_sample = processor(raw_audio=audio_loaded[0], sampling_rate=32000, return_tensors=\"pt\")\\nprint(audio_loaded.shape)\\n# cast the audio data to the correct sampling rate for the model\\n#librispeech_dummy = librispeech_dummy.cast_column(\"audio\", Audio(sampling_rate=processor.sampling_rate))\\n#audio_sample = librispeech_dummy[0][\"audio\"][\"array\"]\\n\\n# pre-process the inputs\\nprint(torch.cat([audio_loaded, audio_loaded], dim = 0).shape)\\ninputs = processor(raw_audio=torch.cat([audio_loaded, audio_loaded], dim = 0).unsqueeze(dim = -1).transpose(0,2), sampling_rate=processor.sampling_rate, return_tensors=\"pt\")\\n\\n# explicitly encode then decode the audio inputs\\nencoder_outputs = model.encode(inputs[\"input_values\"], inputs[\"padding_mask\"])\\naudio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs[\"padding_mask\"])[0]\\n\\n# or the equivalent with a forward pass\\naudio_values = model(inputs[\"input_values\"], inputs[\"padding_mask\"]).audio_values\\n'"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\"\"\"from datasets import load_dataset, Audio\n",
+ "from transformers import EncodecModel, AutoProcessor\n",
+ "\n",
+ "\n",
+ "# load a demonstration datasets\n",
+ "librispeech_dummy = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
+ "\n",
+ "# load the model + processor (for pre-processing the audio)\n",
+ "model = EncodecModel.from_pretrained(\"facebook/encodec_48khz\")\n",
+ "processor = AutoProcessor.from_pretrained(\"facebook/encodec_48khz\")\n",
+ "\n",
+ "path = '/content/Stim_Test_Run01_01_rock.wav'\n",
+ "audio_loaded, sr = torchaudio.load(path)\n",
+ "print(audio_loaded.shape)\n",
+ "audio_loaded = torchaudio.transforms.Resample(sr, 24000)(audio_loaded)\n",
+ "#audio_sample = processor(raw_audio=audio_loaded[0], sampling_rate=32000, return_tensors=\"pt\")\n",
+ "print(audio_loaded.shape)\n",
+ "# cast the audio data to the correct sampling rate for the model\n",
+ "#librispeech_dummy = librispeech_dummy.cast_column(\"audio\", Audio(sampling_rate=processor.sampling_rate))\n",
+ "#audio_sample = librispeech_dummy[0][\"audio\"][\"array\"]\n",
+ "\n",
+ "# pre-process the inputs\n",
+ "print(torch.cat([audio_loaded, audio_loaded], dim = 0).shape)\n",
+ "inputs = processor(raw_audio=torch.cat([audio_loaded, audio_loaded], dim = 0).unsqueeze(dim = -1).transpose(0,2), sampling_rate=processor.sampling_rate, return_tensors=\"pt\")\n",
+ "\n",
+ "# explicitly encode then decode the audio inputs\n",
+ "encoder_outputs = model.encode(inputs[\"input_values\"], inputs[\"padding_mask\"])\n",
+ "audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs[\"padding_mask\"])[0]\n",
+ "\n",
+ "# or the equivalent with a forward pass\n",
+ "audio_values = model(inputs[\"input_values\"], inputs[\"padding_mask\"]).audio_values\n",
+ "\"\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "FZ-JP_DFsD_6",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "FZ-JP_DFsD_6",
+ "outputId": "e462f0f5-8267-449f-e8b3-c11a0a7dd593"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 1, 642560])"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "audio_values.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "F-b6Q4TSjGcc",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 175
+ },
+ "id": "F-b6Q4TSjGcc",
+ "outputId": "cfe3e974-ca58-48de-833f-12b1d8e4e1b8"
+ },
+ "outputs": [
+ {
+ "ename": "NameError",
+ "evalue": "ignored",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m| \u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mencoder_outputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maudio_codes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;31mNameError\u001b[0m: name 'encoder_outputs' is not defined"
+ ]
+ }
+ ],
+ "source": [
+ "encoder_outputs.audio_codes.shape"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.6"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "013191f7a16c49fdbfc9b6cb3b0aa089": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_1c58ed64d7144bcf9c58fe0f89364d61",
+ "IPY_MODEL_ce32b8aaae2f4cd38d5ec78fefaa34ce",
+ "IPY_MODEL_7ed4f572d3534679a3e1e4d90880bd71"
+ ],
+ "layout": "IPY_MODEL_fd741db4538f493588106d753a747593"
+ }
+ },
+ "09bbff09eaa44cdf92612c7f02f05f63": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "0aa3f4c09c854d90b9357d356e8ed46b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0cd8b835ee724b659c92fbee8eb62327": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_50323b27b7ff4823b733bffa97930218",
+ "IPY_MODEL_abcf0eeed63f4331bbc9a044b2d3d65f",
+ "IPY_MODEL_2fa893516c3946729fa0eda21550f658"
+ ],
+ "layout": "IPY_MODEL_7742549b504c4e71b4d0a2119d459936"
+ }
+ },
+ "1c58ed64d7144bcf9c58fe0f89364d61": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_0aa3f4c09c854d90b9357d356e8ed46b",
+ "placeholder": "β",
+ "style": "IPY_MODEL_bbf7a9706dd64f6eb24d4b46ff52bc23",
+ "value": "Downloading (β¦)lve/main/config.json: 100%"
+ }
+ },
+ "2fa893516c3946729fa0eda21550f658": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_410cc9979cc44f5aad48d17cf6042165",
+ "placeholder": "β",
+ "style": "IPY_MODEL_768bf7c8a965476c99693c9aa3ea89cb",
+ "value": " 2.36G/2.36G [00:22<00:00, 235MB/s]"
+ }
+ },
+ "410cc9979cc44f5aad48d17cf6042165": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "4d191abe24514b6a8a5bcb7aa4dd624c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_f83972e889134265aff6737844971e6f",
+ "max": 224,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_9fa6a2c3f81e4a159cf652d8a48acd37",
+ "value": 224
+ }
+ },
+ "50323b27b7ff4823b733bffa97930218": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_6081dfc2472b4097b01bc7608c100ca3",
+ "placeholder": "β",
+ "style": "IPY_MODEL_a810f25c280d48539d7382064588efd7",
+ "value": "Downloading model.safetensors: 100%"
+ }
+ },
+ "6081dfc2472b4097b01bc7608c100ca3": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "768bf7c8a965476c99693c9aa3ea89cb": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "7742549b504c4e71b4d0a2119d459936": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7ed4f572d3534679a3e1e4d90880bd71": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_f74afcde56c64ab389fed5bf7f5964b8",
+ "placeholder": "β",
+ "style": "IPY_MODEL_09bbff09eaa44cdf92612c7f02f05f63",
+ "value": " 7.87k/7.87k [00:00<00:00, 288kB/s]"
+ }
+ },
+ "7f52d5efbb064170ac8d8681ae92f29b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "881fd640878c420cbc14dfd5f8516953": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "92804f5fa9ef49a094ce3d54b051ff9c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9350625ba0fe43949ea335a27f8e402d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_92804f5fa9ef49a094ce3d54b051ff9c",
+ "placeholder": "β",
+ "style": "IPY_MODEL_9ac4c059958446b1af03da2fe2e0ac20",
+ "value": " 224/224 [00:00<00:00, 13.9kB/s]"
+ }
+ },
+ "9ac4c059958446b1af03da2fe2e0ac20": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "9fa6a2c3f81e4a159cf652d8a48acd37": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "a2a0abc232c04d8c96704d6b012227aa": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_dbcc2550fed34fc4b9d1f9b5e7465b85",
+ "placeholder": "β",
+ "style": "IPY_MODEL_aca7c88abc684793880a4619257522c4",
+ "value": "Downloading (β¦)neration_config.json: 100%"
+ }
+ },
+ "a43852c0c4754235a7cf3fb7221eed1d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "a810f25c280d48539d7382064588efd7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "abcf0eeed63f4331bbc9a044b2d3d65f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_b93f98b1284b4cf8bc325e9c4d9a2bf1",
+ "max": 2364427288,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_a43852c0c4754235a7cf3fb7221eed1d",
+ "value": 2364427288
+ }
+ },
+ "aca7c88abc684793880a4619257522c4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "b93f98b1284b4cf8bc325e9c4d9a2bf1": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "bbf7a9706dd64f6eb24d4b46ff52bc23": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "ce32b8aaae2f4cd38d5ec78fefaa34ce": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_7f52d5efbb064170ac8d8681ae92f29b",
+ "max": 7866,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_881fd640878c420cbc14dfd5f8516953",
+ "value": 7866
+ }
+ },
+ "dbcc2550fed34fc4b9d1f9b5e7465b85": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e9895aee370a4d888cefa7a82cd90c00": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f74afcde56c64ab389fed5bf7f5964b8": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f83972e889134265aff6737844971e6f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f938b2a799d5454f8885d5d6bba24b94": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_a2a0abc232c04d8c96704d6b012227aa",
+ "IPY_MODEL_4d191abe24514b6a8a5bcb7aa4dd624c",
+ "IPY_MODEL_9350625ba0fe43949ea335a27f8e402d"
+ ],
+ "layout": "IPY_MODEL_e9895aee370a4d888cefa7a82cd90c00"
+ }
+ },
+ "fd741db4538f493588106d753a747593": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}