diff --git "a/src/Copy_of_MusicGen.ipynb" "b/src/Copy_of_MusicGen.ipynb"
new file mode 100644
--- /dev/null
+++ "b/src/Copy_of_MusicGen.ipynb"
@@ -0,0 +1,2880 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "70300319-d206-43ce-b3bf-3da6b079f20f",
+ "metadata": {
+ "id": "70300319-d206-43ce-b3bf-3da6b079f20f"
+ },
+ "source": [
+ "## MusicGen in 🤗 Transformers\n",
+ "\n",
+ "**by [Sanchit Gandhi](https://huggingface.co/sanchit-gandhi)**\n",
+ "\n",
+ "MusicGen is a Transformer-based model capable fo generating high-quality music samples conditioned on text descriptions or audio prompts. It was proposed in the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet et al. from Meta AI.\n",
+ "\n",
+ "The MusicGen model can be de-composed into three distinct stages:\n",
+ "1. The text descriptions are passed through a frozen text encoder model to obtain a sequence of hidden-state representations\n",
+ "2. The MusicGen decoder is then trained to predict discrete audio tokens, or *audio codes*, conditioned on these hidden-states\n",
+ "3. These audio tokens are then decoded using an audio compression model, such as EnCodec, to recover the audio waveform\n",
+ "\n",
+ "The pre-trained MusicGen checkpoints use Google's [t5-base](https://huggingface.co/t5-base) as the text encoder model, and [EnCodec 32kHz](https://huggingface.co/facebook/encodec_32khz) as the audio compression model. The MusicGen decoder is a pure language model architecture,\n",
+ "trained from scratch on the task of music generation.\n",
+ "\n",
+ "The novelty in the MusicGen model is how the audio codes are predicted. Traditionally, each codebook has to be predicted by a separate model (i.e. hierarchically) or by continuously refining the output of the Transformer model (i.e. upsampling). MusicGen uses an efficient *token interleaving pattern*, thus eliminating the need to cascade multiple models to predict a set of codebooks. Instead, it is able to generate the full set of codebooks in a single forward pass of the decoder, resulting in much faster inference.\n",
+ "\n",
+ "
\n",
+ " \n",
+ "
\n",
+ "\n",
+ "\n",
+ "**Figure 1:** Codebook delay pattern used by MusicGen. Figure taken from the [MusicGen paper](https://arxiv.org/abs/2306.05284).\n"
+ ]
+ },
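+ {
+ "cell_type": "markdown",
+ "id": "delay-pattern-note",
+ "metadata": {},
+ "source": [
+ "To build intuition for this pattern, here is a minimal sketch (our own illustration, not MusicGen's actual implementation): codebook k is shifted right by k steps, with a placeholder token filling the gap, so that a single decoder step can predict one token from every codebook in parallel:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "delay-pattern-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "# toy example: 4 codebooks, 8 timesteps, -1 as the placeholder token\n",
+ "num_codebooks, seq_len, pad = 4, 8, -1\n",
+ "codes = torch.arange(num_codebooks * seq_len).reshape(num_codebooks, seq_len)\n",
+ "\n",
+ "# shift codebook k right by k steps (the 'delay' of Figure 1)\n",
+ "delayed = torch.full((num_codebooks, seq_len + num_codebooks - 1), pad)\n",
+ "for k in range(num_codebooks):\n",
+ " delayed[k, k : k + seq_len] = codes[k]\n",
+ "\n",
+ "print(delayed)"
+ ]
+ },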
+ {
+ "cell_type": "markdown",
+ "id": "e70e6dbb-3211-4ef9-93f6-efaba764ac77",
+ "metadata": {
+ "id": "e70e6dbb-3211-4ef9-93f6-efaba764ac77"
+ },
+ "source": [
+ "## Prepare the Environment"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "04d1fb09-4e19-4e82-a4fa-eea7b20bb96c",
+ "metadata": {
+ "id": "04d1fb09-4e19-4e82-a4fa-eea7b20bb96c"
+ },
+ "source": [
+ "Let’s make sure we’re connected to a GPU to run this notebook. To get a GPU, click `Runtime` -> `Change runtime type`, then change `Hardware accelerator` from `None` to `GPU`. We can verify that we’ve been assigned a GPU and view its specifications through the `nvidia-smi` command:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "21d38c22-bb79-495c-8aa9-09ceabb2957a",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "21d38c22-bb79-495c-8aa9-09ceabb2957a",
+ "outputId": "bcb3fd96-cc9c-45fc-a0ce-2fc79398cedc",
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Thu Sep 7 22:36:00 2023 \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n",
+ "|-------------------------------+----------------------+----------------------+\n",
+ "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
+ "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
+ "| | | MIG M. |\n",
+ "|===============================+======================+======================|\n",
+ "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n",
+ "| N/A 43C P0 55W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n",
+ "| N/A 41C P0 52W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n",
+ "| N/A 45C P0 54W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n",
+ "| N/A 43C P0 60W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n",
+ "| N/A 44C P0 57W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n",
+ "| N/A 43C P0 55W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n",
+ "| N/A 45C P0 54W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 7 NVIDIA A100-SXM... On | 00000000:A0:1D.0 Off | 0 |\n",
+ "| N/A 43C P0 58W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ " \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| Processes: |\n",
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
+ "| ID ID Usage |\n",
+ "|=============================================================================|\n",
+ "| No running processes found |\n",
+ "+-----------------------------------------------------------------------------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1abcac4f-06b0-41c7-b7e4-960ddd297afd",
+ "metadata": {
+ "id": "1abcac4f-06b0-41c7-b7e4-960ddd297afd"
+ },
+ "source": [
+ "We see here that we've got on Tesla T4 16GB GPU, although this may vary for you depending on GPU availablity and Colab GPU assignment.\n",
+ "\n",
+ "Next, we install the 🤗 Transformers package from the main branch, as well as 🤗 Datasets package to load audio files for audio-prompted generation:"
+ ]
+ },
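+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "install-packages",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# one possible invocation (assumed; adjust to your environment)\n",
+ "!pip install --upgrade git+https://github.com/huggingface/transformers.git datasets[audio]"
+ ]
+ },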
+ {
+ "cell_type": "markdown",
+ "id": "77ee39cc-654b-4f0e-b601-013e484c16f0",
+ "metadata": {
+ "id": "77ee39cc-654b-4f0e-b601-013e484c16f0"
+ },
+ "source": [
+ "## Load the Model\n",
+ "\n",
+ "The pre-trained MusicGen small, medium and large checkpoints can be loaded from the [pre-trained weights](https://huggingface.co/models?search=facebook/musicgen-) on the Hugging Face Hub. Change the repo id with the checkpoint size you wish to load. We'll default to the small checkpoint, which is the fastest of the three but has the lowest audio quality:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "id": "b0d87424-9f38-4658-ba47-2a465d52ad77",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 113,
+ "referenced_widgets": [
+ "013191f7a16c49fdbfc9b6cb3b0aa089",
+ "1c58ed64d7144bcf9c58fe0f89364d61",
+ "ce32b8aaae2f4cd38d5ec78fefaa34ce",
+ "7ed4f572d3534679a3e1e4d90880bd71",
+ "fd741db4538f493588106d753a747593",
+ "0aa3f4c09c854d90b9357d356e8ed46b",
+ "bbf7a9706dd64f6eb24d4b46ff52bc23",
+ "7f52d5efbb064170ac8d8681ae92f29b",
+ "881fd640878c420cbc14dfd5f8516953",
+ "f74afcde56c64ab389fed5bf7f5964b8",
+ "09bbff09eaa44cdf92612c7f02f05f63",
+ "0cd8b835ee724b659c92fbee8eb62327",
+ "50323b27b7ff4823b733bffa97930218",
+ "abcf0eeed63f4331bbc9a044b2d3d65f",
+ "2fa893516c3946729fa0eda21550f658",
+ "7742549b504c4e71b4d0a2119d459936",
+ "6081dfc2472b4097b01bc7608c100ca3",
+ "a810f25c280d48539d7382064588efd7",
+ "b93f98b1284b4cf8bc325e9c4d9a2bf1",
+ "a43852c0c4754235a7cf3fb7221eed1d",
+ "410cc9979cc44f5aad48d17cf6042165",
+ "768bf7c8a965476c99693c9aa3ea89cb",
+ "f938b2a799d5454f8885d5d6bba24b94",
+ "a2a0abc232c04d8c96704d6b012227aa",
+ "4d191abe24514b6a8a5bcb7aa4dd624c",
+ "9350625ba0fe43949ea335a27f8e402d",
+ "e9895aee370a4d888cefa7a82cd90c00",
+ "dbcc2550fed34fc4b9d1f9b5e7465b85",
+ "aca7c88abc684793880a4619257522c4",
+ "f83972e889134265aff6737844971e6f",
+ "9fa6a2c3f81e4a159cf652d8a48acd37",
+ "92804f5fa9ef49a094ce3d54b051ff9c",
+ "9ac4c059958446b1af03da2fe2e0ac20"
+ ]
+ },
+ "id": "b0d87424-9f38-4658-ba47-2a465d52ad77",
+ "outputId": "2af6a8e9-ff4c-4f2d-8aa2-89f35f0575e7",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import MusicgenForConditionalGeneration\n",
+ "cache_dir = '/fsx/proj-fmri/ckadirt/b2m/cache/'\n",
+ "model = MusicgenForConditionalGeneration.from_pretrained(\"facebook/musicgen-small\", cache_dir=cache_dir)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4981d112-407c-4120-86aa-5c6a170543f7",
+ "metadata": {
+ "id": "4981d112-407c-4120-86aa-5c6a170543f7"
+ },
+ "source": [
+ "We can then place the model on our accelerator device (if available), or leave it on the CPU otherwise:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "id": "9508dee8-39df-46fe-82f3-6cc2e9f21a97",
+ "metadata": {
+ "id": "9508dee8-39df-46fe-82f3-6cc2e9f21a97",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
+ "model.to(device);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "ba5cdeee-27d0-4834-b7dc-4403864550f7",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Thu Sep 7 22:37:19 2023 \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n",
+ "|-------------------------------+----------------------+----------------------+\n",
+ "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
+ "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
+ "| | | MIG M. |\n",
+ "|===============================+======================+======================|\n",
+ "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n",
+ "| N/A 45C P0 78W / 400W | 3593MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n",
+ "| N/A 41C P0 52W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n",
+ "| N/A 45C P0 54W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n",
+ "| N/A 43C P0 60W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n",
+ "| N/A 44C P0 57W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n",
+ "| N/A 43C P0 55W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n",
+ "| N/A 45C P0 54W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ "| 7 NVIDIA A100-SXM... On | 00000000:A0:1D.0 Off | 0 |\n",
+ "| N/A 43C P0 58W / 400W | 3MiB / 40960MiB | 0% Default |\n",
+ "| | | Disabled |\n",
+ "+-------------------------------+----------------------+----------------------+\n",
+ " \n",
+ "+-----------------------------------------------------------------------------+\n",
+ "| Processes: |\n",
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
+ "| ID ID Usage |\n",
+ "|=============================================================================|\n",
+ "| 0 N/A N/A 742410 C ...3/envs/mindeye/bin/python 3590MiB |\n",
+ "+-----------------------------------------------------------------------------+\n"
+ ]
+ }
+ ],
+ "source": [
+ "!nvidia-smi"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f6e1166e-1335-4555-9ec4-223d1fbcb547",
+ "metadata": {
+ "id": "f6e1166e-1335-4555-9ec4-223d1fbcb547"
+ },
+ "source": [
+ "## Generation\n",
+ "\n",
+ "MusicGen is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly\n",
+ "better results than greedy, thus we encourage sampling mode to be used where possible. Sampling is enabled by default,\n",
+ "and can be explicitly specified by setting `do_sample=True` in the call to `MusicgenForConditionalGeneration.generate` (see below).\n",
+ "\n",
+ "### Unconditional Generation\n",
+ "\n",
+ "The inputs for unconditional (or 'null') generation can be obtained through the method `MusicgenForConditionalGeneration.get_unconditional_inputs`. We can then run auto-regressive generation using the `.generate` method, specifying `do_sample=True` to enable sampling mode:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "fb7708e8-e4f1-4ab8-b04a-19395d78dea2",
+ "metadata": {
+ "id": "fb7708e8-e4f1-4ab8-b04a-19395d78dea2",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "unconditional_inputs = model.get_unconditional_inputs(num_samples=1)\n",
+ "\n",
+ "audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "94cb74df-c194-4d2e-930a-12473b08a919",
+ "metadata": {
+ "id": "94cb74df-c194-4d2e-930a-12473b08a919"
+ },
+ "source": [
+ "The audio outputs are a three-dimensional Torch tensor of shape `(batch_size, num_channels, sequence_length)`. To listen\n",
+ "to the generated audio samples, you can either play them in an ipynb notebook:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "15f0bc7c-b899-4e7a-943e-594e73f080ea",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 75
+ },
+ "id": "15f0bc7c-b899-4e7a-943e-594e73f080ea",
+ "outputId": "9a47adc3-17b1-40d6-989d-73319c9ea7ee",
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " Your browser does not support the audio element.\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from IPython.display import Audio\n",
+ "\n",
+ "sampling_rate = model.config.audio_encoder.sampling_rate\n",
+ "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6de58334-40f7-4924-addb-2d6ff34c0590",
+ "metadata": {
+ "id": "6de58334-40f7-4924-addb-2d6ff34c0590"
+ },
+ "source": [
+ "Or save them as a `.wav` file using a third-party library, e.g. `scipy` (note here that we also need to remove the channel dimension from our audio tensor):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "04291f52-0a75-4ddb-9eff-e853d0f17288",
+ "metadata": {
+ "id": "04291f52-0a75-4ddb-9eff-e853d0f17288"
+ },
+ "outputs": [],
+ "source": [
+ "import scipy\n",
+ "\n",
+ "scipy.io.wavfile.write(\"musicgen_out.wav\", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e52ff5b2-c170-4079-93a4-a02acbdaeb39",
+ "metadata": {
+ "id": "e52ff5b2-c170-4079-93a4-a02acbdaeb39"
+ },
+ "source": [
+ "The argument `max_new_tokens` specifies the number of new tokens to generate. As a rule of thumb, you can work out the length of the generated audio sample in seconds by using the frame rate of the EnCodec model:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "d75ad107-e19b-47f3-9cf1-5102ab4ae74a",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "d75ad107-e19b-47f3-9cf1-5102ab4ae74a",
+ "outputId": "c021297b-837b-4d35-e01b-34a66a33c1dd",
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "5.12"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "audio_length_in_s = 256 / model.config.audio_encoder.frame_rate\n",
+ "\n",
+ "audio_length_in_s"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9a0e999b-2595-4090-8e1a-acfaa42d2581",
+ "metadata": {
+ "id": "9a0e999b-2595-4090-8e1a-acfaa42d2581"
+ },
+ "source": [
+ "### Text-Conditional Generation\n",
+ "\n",
+ "The model can generate an audio sample conditioned on a text prompt through use of the `MusicgenProcessor` to pre-process\n",
+ "the inputs. The pre-processed inputs can then be passed to the `.generate` method to generate text-conditional audio samples.\n",
+ "Again, we enable sampling mode by setting `do_sample=True`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 59,
+ "id": "5fba4154-13f6-403a-958b-101d6eacfb6e",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 75
+ },
+ "id": "5fba4154-13f6-403a-958b-101d6eacfb6e",
+ "outputId": "a4b87d4b-1db0-49af-dac0-eb600c47cbfb",
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " Your browser does not support the audio element.\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 59,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from transformers import AutoProcessor\n",
+ "from einops import rearrange\n",
+ "\n",
+ "processor = AutoProcessor.from_pretrained(\"facebook/musicgen-small\")\n",
+ "\n",
+ "inputs = processor(\n",
+ " text=[\"hiphop beat\", \"90s rock song with loud guitars and heavy drums\"],\n",
+ " padding=True,\n",
+ " return_tensors=\"pt\",\n",
+ ")\n",
+ "\n",
+ "audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256)\n",
+ "\n",
+ "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)"
+ ]
+ },
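+ {
+ "cell_type": "markdown",
+ "id": "manual-forward-pass",
+ "metadata": {},
+ "source": [
+ "### Manual Forward Pass and Training\n",
+ "\n",
+ "The cells that follow take the pipeline apart by hand: we re-encode the generated audio into EnCodec tokens, run the frozen text encoder and the encoder-to-decoder projection ourselves, and then fit the decoder on the resulting text/token pair with a plain cross-entropy training loop."
+ ]
+ },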
+ {
+ "cell_type": "code",
+ "execution_count": 60,
+ "id": "_FZb_Zo-Dajl",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "_FZb_Zo-Dajl",
+ "outputId": "9834a745-ac10-4142-e22f-6998997304e5",
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([1, 2, 4, 253])\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " tokens_train = model.audio_encoder(audio_values)\n",
+ " print(tokens_train.audio_codes.shape)\n",
+ " tokens_on_format = rearrange(tokens_train.audio_codes, \"n b c l -> (n b c) l\")\n",
+ " # make a copy of the tokens excluding to the computation graph\n",
+ " tokens_on_format = tokens_on_format.detach().clone()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 61,
+ "id": "VHyalHz78TlY",
+ "metadata": {
+ "id": "VHyalHz78TlY",
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([2, 13, 768])\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " encoder_hidden_states = model.text_encoder(inputs.input_ids, attention_mask = inputs.attention_mask.detach().clone()).last_hidden_state.detach().clone()\n",
+ " print(encoder_hidden_states.shape)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "id": "dOnCaS3F9yfz",
+ "metadata": {
+ "id": "dOnCaS3F9yfz",
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([2, 13, 1024])\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " encoder_hidden_states = model.enc_to_dec_proj(encoder_hidden_states).detach().clone()\n",
+ " print(encoder_hidden_states.shape)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 66,
+ "id": "z56a5Yjqoum5",
+ "metadata": {
+ "id": "z56a5Yjqoum5",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "pad_token_id = model.generation_config.pad_token_id\n",
+ "decoder_input_ids = (\n",
+ " torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long)\n",
+ " * pad_token_id\n",
+ ").to(device).detach().clone()\n",
+ "\n",
+ "inputs.attention_mask = inputs.attention_mask.detach().clone()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 67,
+ "id": "xLrDtoYhqBiP",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "xLrDtoYhqBiP",
+ "outputId": "78909eb6-af29-4655-d473-86359093c9d7",
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([8, 1])"
+ ]
+ },
+ "execution_count": 67,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "decoder_input_ids.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 107,
+ "id": "b67b9b42-fd04-4b0c-ac36-a5b7016ada16",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([8, 253])"
+ ]
+ },
+ "execution_count": 107,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokens_on_format.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 71,
+ "id": "hlKZ7dJgH1VJ",
+ "metadata": {
+ "id": "hlKZ7dJgH1VJ",
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "results_manual = model.decoder(encoder_hidden_states = encoder_hidden_states * inputs.attention_mask[..., None] , input_ids = tokens_on_format[:,:-1], encoder_attention_mask = inputs.attention_mask)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 72,
+ "id": "62b31feb-e392-4b5c-9855-25d97364256b",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([8, 252, 2048])"
+ ]
+ },
+ "execution_count": 72,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "results_manual.logits.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "id": "a5fa2fd1",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([2008, 2048]) torch.Size([2008])\n"
+ ]
+ }
+ ],
+ "source": [
+ "def training_step(labels, results_manual):\n",
+ " # as the labels are the same than the input, we should remove the first column\n",
+ " labels = labels[:, 1:]\n",
+ " # as the logits are the same than the input, we should remove the last column\n",
+ " results_manual.logits = results_manual.logits[:, :-1, :]\n",
+ " loss_fct = torch.nn.CrossEntropyLoss()\n",
+ " print(rearrange(results_manual.logits, \"c t v -> (c t) v\").shape, rearrange(labels, \"c t -> (c t)\").shape)\n",
+ " loss = loss_fct(rearrange(results_manual.logits, \"c t v -> (c t) v\"), rearrange(labels, \"c t -> (c t)\"))\n",
+ " return loss\n",
+ "\n",
+ "loss = training_step(tokens_on_format_test, results_manual)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 74,
+ "id": "44f38b39-c06b-4e50-b233-ef6b6f829de5",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(10.8952, device='cuda:0', grad_fn=)"
+ ]
+ },
+ "execution_count": 74,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 75,
+ "id": "4022c586-423e-489e-9647-c0c24a378092",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "model.zero_grad()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "06335817-59a3-47b2-a2df-10d71acdc3c4",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "tokens_on_format_test_tt = tokens_on_format_test.detach()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 97,
+ "id": "82367433",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.3841, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(6.5876, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(5.2953, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(3.9745, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(3.5427, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(3.1808, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.9803, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.8768, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.7334, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.6114, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.5443, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.4862, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.4340, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.3890, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.3461, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.3014, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.2524, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.2108, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.1693, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.1228, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.0805, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.0398, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(2.0003, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.9617, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.9265, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.8898, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.8527, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.8144, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.7804, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.7465, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.7152, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.6817, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.6494, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.6182, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.5891, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.5631, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.5393, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.5514, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.5996, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.5329, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.5090, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.4846, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.4633, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.4849, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.3862, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.4540, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.3665, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.3937, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.3255, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.3280, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.2869, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.2772, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.2436, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.2343, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.2054, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.1852, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.1706, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.1331, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.1386, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.1027, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.0897, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.0732, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.1328, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.4694, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.3952, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.2203, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.2310, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.1787, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.2025, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.1772, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.1184, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.0861, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.0596, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.0484, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.0140, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.0278, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(0.9599, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(0.9769, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(0.9378, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(0.9275, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(0.9106, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(0.8895, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(0.8961, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(0.9111, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.1137, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.5037, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.5237, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.3981, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.1750, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.3016, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.1642, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.1759, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.0935, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.0816, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.0367, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(1.0183, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(0.9768, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(0.9370, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(0.9176, device='cuda:0', grad_fn=)\n",
+ "torch.Size([2008, 2048]) torch.Size([2008])\n",
+ "tensor(0.8828, device='cuda:0', grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# move all the stuff to the CPU\n",
+ "#encoder_hidden_states = encoder_hidden_states.cpu()\n",
+ "#inputs.attention_mask = inputs.attention_mask.cpu()\n",
+ "#decoder_attention_mask = decoder_attention_mask.cpu()\n",
+ "#tokens_on_format_test = tokens_on_format_test.cpu()\n",
+ "#results_manual.logits = results_manual.logits.cpu()\n",
+ "#model = model.cpu()\n",
+ "\n",
+ "\n",
+ "optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)\n",
+ "for i in range(100):\n",
+ " optimizer.zero_grad()\n",
+ " results_manual = model.decoder(encoder_hidden_states = encoder_hidden_states * inputs.attention_mask[..., None] , input_ids = tokens_on_format[:,:-1], encoder_attention_mask = inputs.attention_mask)\n",
+ " loss = training_step(tokens_on_format[:,:-1].detach().clone(), results_manual)\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " \n",
+ " \n",
+ " print(loss)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 98,
+ "id": "c9870eac-70ca-4ff5-a4cf-2138bef89de1",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([ 578, 68, 244, 753, 135, 135, 40, 244, 1358, 83, 1265, 91,\n",
+ " 555, 816, 236, 235, 235, 244, 135, 135, 609, 244, 609, 135,\n",
+ " 1081, 480, 549, 14, 327, 613, 244, 244, 83, 135, 66, 244,\n",
+ " 472, 403, 472, 68, 1757, 753, 1348, 289, 609, 289, 609, 289,\n",
+ " 609, 539, 578, 472, 83, 609, 66, 609, 66, 609, 289, 609,\n",
+ " 289, 289, 403, 570, 23, 116, 14, 376, 609, 1757, 135, 1358,\n",
+ " 83, 40, 1205, 578, 444, 68, 135, 1358, 289, 83, 289, 1265,\n",
+ " 1757, 289, 289, 1081, 1036, 68, 1453, 1265, 1265, 1265, 289, 8,\n",
+ " 376, 376, 289, 778, 403, 444, 68, 1265, 1757, 1453, 376, 376,\n",
+ " 376, 289, 609, 1265, 539, 434, 334, 1986, 83, 244, 235, 244,\n",
+ " 1265, 609, 244, 1265, 1453, 172, 1036, 68, 83, 1453, 1453, 609,\n",
+ " 1453, 1265, 1265, 1265, 1265, 1205, 555, 588, 116, 235, 244, 244,\n",
+ " 172, 275, 68, 376, 753, 135, 376, 986, 1348, 609, 8, 609,\n",
+ " 609, 609, 609, 289, 609, 609, 778, 403, 1360, 68, 1265, 21,\n",
+ " 1358, 1453, 40, 1453, 289, 1453, 289, 91, 578, 226, 1265, 1265,\n",
+ " 376, 1265, 289, 1453, 289, 244, 135, 244, 403, 1036, 425, 609,\n",
+ " 244, 1043, 244, 1348, 135, 1348, 244, 1043, 1488, 403, 425, 40,\n",
+ " 83, 244, 254, 40, 135, 1265, 135, 40, 244, 979, 759, 834,\n",
+ " 327, 135, 289, 235, 235, 1358, 327, 289, 244, 778, 403, 1953,\n",
+ " 1617, 244, 83, 14, 135, 1358, 244, 40, 135, 83, 539, 578,\n",
+ " 68, 83, 83, 1453, 1265, 135, 40, 289, 1453, 609, 1453, 403,\n",
+ " 275], device='cuda:0')"
+ ]
+ },
+ "execution_count": 98,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokens_on_format[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 99,
+ "id": "61b8b0e2-8185-4801-9ea3-8b05820ed062",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "results_manual = model.decoder(encoder_hidden_states = encoder_hidden_states * inputs.attention_mask[..., None] , input_ids = tokens_on_format[:,:-1], encoder_attention_mask = inputs.attention_mask)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 106,
+ "id": "2a1cfaa7-707c-4ed1-8a74-14d595a2d7bf",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "tensor([ 135, 244, 235, 609, 1265, 40, 1358, 68, 83, 816, 236, 555,\n",
+ " 549, 753, 91, 1081, 480, 14, 327, 613, 116, 23, 1757, 1379,\n",
+ " 21, 289, 588, 10, 376, 2007, 148, 146, 789, 1602, 425, 429,\n",
+ " 570, 1132, 172, 1522, 403, 1488, 1846, 1043, 8, 66, 902, 254,\n",
+ " 539, 1106, 472, 374, 1453, 1348, 1671, 226, 778, 1195, 1360, 275,\n",
+ " 1334, 759, 1875, 949, 993, 1194, 444, 1796, 97, 1676, 1187, 1304,\n",
+ " 1953, 492, 1698, 578, 979, 1036, 1775, 434, 1210, 1617, 15, 149,\n",
+ " 276, 1487, 1540, 986, 1398, 721, 243, 4, 1748, 131, 1986, 1544,\n",
+ " 223, 227, 1109, 1815], device='cuda:0')\n"
+ ]
+ }
+ ],
+ "source": [
+ "max_idx = torch.topk(results_manual.logits[0][9], k=100)[1]\n",
+ "print(max_idx)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "UV1XbuCRONUA",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "UV1XbuCRONUA",
+ "outputId": "0b846e0f-180d-4bd6-946f-fb29665479d1"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([8, 1, 2048])\n",
+ "tensor([-0.1700, -3.6208, -0.9766, -1.1846, -1.3526, 1.4435, 2.6102, -2.6462,\n",
+ " -1.3472, -1.6042], device='cuda:0', grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(results_manual.logits.shape)\n",
+ "#results_manual.past_key_values.shape\n",
+ "print(results_manual.logits[1][0][0:10])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "YpFaaCefZ4mp",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "YpFaaCefZ4mp",
+ "outputId": "c49266f8-4602-47ae-94c9-de472661ba5b"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2048\n"
+ ]
+ }
+ ],
+ "source": [
+ "inputs = processor(\n",
+ " text=[\"michael jackson signing pop\", \"90s rock song with loud guitars and heavy drums\"],\n",
+ " padding=True,\n",
+ " return_tensors=\"pt\",\n",
+ ").to('cuda')\n",
+ "\n",
+ "pad_token_id = model.generation_config.pad_token_id\n",
+ "print(pad_token_id)\n",
+ "decoder_input_ids = (\n",
+ " torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long)\n",
+ " * pad_token_id\n",
+ ").to('cuda')\n",
+ "\n",
+ "end = model(**inputs, decoder_input_ids=decoder_input_ids)\n",
+ "#logits.shape # (bsz * num_codebooks, tgt_len, vocab_size)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "Fdr_sVdLQF3m",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Fdr_sVdLQF3m",
+ "outputId": "d4f55f93-e76f-4a90-9253-b6fb51a9add5"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'input_ids': tensor([[2278, 9, 15, 40, 3, 9325, 739, 8097, 2783, 1, 0, 0,\n",
+ " 0],\n",
+ " [2777, 7, 2480, 2324, 28, 8002, 5507, 7, 11, 2437, 5253, 7,\n",
+ " 1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],\n",
+ " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}"
+ ]
+ },
+ "execution_count": 76,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "inputs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "FyTh1KykP1OT",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "FyTh1KykP1OT",
+ "outputId": "036edf58-44b1-40b8-cbad-137ddad5be9b"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([8, 1, 2048])\n",
+ "tensor([-0.1700, -3.6208, -0.9766, -1.1846, -1.3526, 1.4435, 2.6102, -2.6462,\n",
+ " -1.3472, -1.6042], device='cuda:0', grad_fn=)\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(end.logits.shape)\n",
+ "print(end.logits[1][0][0:10])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "h_0zNWeSIXxC",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "h_0zNWeSIXxC",
+ "outputId": "a7a3fe57-ceca-4e35-b8a6-7c870fcfdb54"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[2048],\n",
+ " [2048],\n",
+ " [2048],\n",
+ " [2048],\n",
+ " [2048],\n",
+ " [2048],\n",
+ " [2048],\n",
+ " [2048]], device='cuda:0')"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "print(end.logits[1][1][1:10])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "-1y5EZojYzDz",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "-1y5EZojYzDz",
+ "outputId": "5db3d81c-0a18-4386-92c7-b20354e2c902"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "odict_keys(['logits', 'past_key_values', 'encoder_last_hidden_state'])"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "end.keys()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "tL-0dkVkbeka",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "tL-0dkVkbeka",
+ "outputId": "2838c153-9344-4ebc-96fb-98331e351e72"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "24"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(end.past_key_values)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2X62iV8iVNRU",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 75
+ },
+ "id": "2X62iV8iVNRU",
+ "outputId": "6b97e3ff-f8b7-40a1-9139-d473354b115c"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " Your browser does not support the audio element.\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2xlbMjFTTUBd",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "2xlbMjFTTUBd",
+ "outputId": "c0228562-0a8e-4198-9171-027a5d49e8fc"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 2775, 7, 2783, 1463, 28, 7981, 63, 5253, 7, 11,\n",
+ " 13353, 1, 0],\n",
+ " [ 2777, 7, 2480, 2324, 28, 8002, 5507, 7, 11, 2437,\n",
+ " 5253, 7, 1]], device='cuda:0')"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "inputs.input_ids"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4851a94c-ae02-41c9-b1dd-c1422ba34dc0",
+ "metadata": {
+ "id": "4851a94c-ae02-41c9-b1dd-c1422ba34dc0"
+ },
+ "source": [
+ "The `guidance_scale` is used in classifier free guidance (CFG), setting the weighting between the conditional logits\n",
+ "(which are predicted from the text prompts) and the unconditional logits (which are predicted from an unconditional or\n",
+ "'null' prompt). A higher guidance scale encourages the model to generate samples that are more closely linked to the input\n",
+ "prompt, usually at the expense of poorer audio quality. CFG is enabled by setting `guidance_scale > 1`. For best results,\n",
+ "use a `guidance_scale=3` (default) for text and audio-conditional generation."
+ ]
+ },
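+ {
+ "cell_type": "markdown",
+ "id": "cfg-formula-note",
+ "metadata": {},
+ "source": [
+ "As a rough sketch (the usual CFG formula, not necessarily the exact code path inside `generate`), the two sets of logits are combined as follows:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cfg-formula-sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def cfg_logits(cond_logits, uncond_logits, guidance_scale):\n",
+ " # guidance_scale=1 recovers the conditional logits unchanged;\n",
+ " # larger values push the distribution towards the text prompt\n",
+ " return uncond_logits + guidance_scale * (cond_logits - uncond_logits)"
+ ]
+ },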
+ {
+ "cell_type": "markdown",
+ "id": "d391b2a1-6376-4b69-b562-4388b731cf60",
+ "metadata": {
+ "id": "d391b2a1-6376-4b69-b562-4388b731cf60"
+ },
+ "source": [
+ "### Audio-Prompted Generation\n",
+ "\n",
+ "The same `MusicgenProcessor` can be used to pre-process an audio prompt that is used for audio continuation. In the\n",
+ "following example, we load an audio file using the 🤗 Datasets library, pre-process it using the processor class,\n",
+ "and then forward the inputs to the model for generation:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "56a5c28a-f6c1-4ac8-ae08-6776a2b2c5b8",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 75
+ },
+ "id": "56a5c28a-f6c1-4ac8-ae08-6776a2b2c5b8",
+ "outputId": "81c95bfd-649d-424f-a764-0f1b881f37dd"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " Your browser does not support the audio element.\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "dataset = load_dataset(\"sanchit-gandhi/gtzan\", split=\"train\", streaming=True)\n",
+ "sample = next(iter(dataset))[\"audio\"]\n",
+ "\n",
+ "# take the first half of the audio sample\n",
+ "sample[\"array\"] = sample[\"array\"][: len(sample[\"array\"]) // 2]\n",
+ "\n",
+ "inputs = processor(\n",
+ " audio=sample[\"array\"],\n",
+ " sampling_rate=sample[\"sampling_rate\"],\n",
+ " text=[\"80s blues track with groovy saxophone\"],\n",
+ " padding=True,\n",
+ " return_tensors=\"pt\",\n",
+ ")\n",
+ "\n",
+ "audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256)\n",
+ "\n",
+ "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "77518aa4-1b9b-4af6-b5ac-8ecdcb79b4cc",
+ "metadata": {
+ "id": "77518aa4-1b9b-4af6-b5ac-8ecdcb79b4cc"
+ },
+ "source": [
+ "To demonstrate batched audio-prompted generation, we'll slice our sample audio by two different proportions to give two audio samples of different length.\n",
+ "Since the input audio prompts vary in length, they will be *padded* to the length of the longest audio sample in the batch before being passed to the model.\n",
+ "\n",
+ "To recover the final audio samples, the `audio_values` generated can be post-processed to remove padding by using the processor class once again:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5495f568-51ca-439d-b47b-8b52e89b78f1",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 75
+ },
+ "id": "5495f568-51ca-439d-b47b-8b52e89b78f1",
+ "outputId": "68866570-000b-4239-bed5-75bc21e828aa"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ " Your browser does not support the audio element.\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sample = next(iter(dataset))[\"audio\"]\n",
+ "\n",
+ "# take the first quater of the audio sample\n",
+ "sample_1 = sample[\"array\"][: len(sample[\"array\"]) // 4]\n",
+ "\n",
+ "# take the first half of the audio sample\n",
+ "sample_2 = sample[\"array\"][: len(sample[\"array\"]) // 2]\n",
+ "\n",
+ "inputs = processor(\n",
+ " audio=[sample_1, sample_2],\n",
+ " sampling_rate=sample[\"sampling_rate\"],\n",
+ " text=[\"80s blues track with groovy saxophone\", \"90s rock song with loud guitars and heavy drums\"],\n",
+ " padding=True,\n",
+ " return_tensors=\"pt\",\n",
+ ")\n",
+ "\n",
+ "audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256)\n",
+ "\n",
+ "# post-process to remove padding from the batched audio\n",
+ "audio_values = processor.batch_decode(audio_values, padding_mask=inputs.padding_mask)\n",
+ "\n",
+ "Audio(audio_values[0], rate=sampling_rate)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "viwTDmzl8ZDN",
+ "metadata": {
+ "id": "viwTDmzl8ZDN"
+ },
+ "source": [
+ "## Generation Config\n",
+ "\n",
+ "The default parameters that control the generation process, such as sampling, guidance scale and number of generated tokens, can be found in the model's generation config, and updated as desired. Let's first inspect the default generation config:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0zM4notb8Y1g",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "0zM4notb8Y1g",
+ "outputId": "576755e5-42b9-48bc-be13-76322837fea0"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "GenerationConfig {\n",
+ " \"_from_model_config\": true,\n",
+ " \"bos_token_id\": 2048,\n",
+ " \"decoder_start_token_id\": 2048,\n",
+ " \"do_sample\": true,\n",
+ " \"guidance_scale\": 3.0,\n",
+ " \"max_length\": 1500,\n",
+ " \"pad_token_id\": 2048,\n",
+ " \"transformers_version\": \"4.34.0.dev0\"\n",
+ "}"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.generation_config"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "DLSnSwau8jyW",
+ "metadata": {
+ "id": "DLSnSwau8jyW"
+ },
+ "source": [
+ "Alright! We see that the model defaults to using sampling mode (`do_sample=True`), a guidance scale of 3, and a maximum generation length of 1500 (which is equivalent to 30s of audio). You can update any of these attributes to change the default generation parameters:"
+ ]
+ },
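+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "max-length-check",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# sanity check: 1500 tokens at the EnCodec frame rate corresponds to 30s of audio\n",
+ "1500 / model.config.audio_encoder.frame_rate"
+ ]
+ },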
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ensSj1IB81dA",
+ "metadata": {
+ "id": "ensSj1IB81dA"
+ },
+ "outputs": [],
+ "source": [
+ "# increase the guidance scale to 4.0\n",
+ "model.generation_config.guidance_scale = 4.0\n",
+ "\n",
+ "# set the max new tokens to 256\n",
+ "model.generation_config.max_new_tokens = 256\n",
+ "\n",
+ "# set the softmax sampling temperature to 1.5\n",
+ "model.generation_config.temperature = 1.5"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "UjqGnfc-9ZFJ",
+ "metadata": {
+ "id": "UjqGnfc-9ZFJ"
+ },
+ "source": [
+ "Re-running generation now will use the newly defined values in the generation config:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "KAExrhDl9YvS",
+ "metadata": {
+ "id": "KAExrhDl9YvS"
+ },
+ "outputs": [],
+ "source": [
+ "audio_values = model.generate(**inputs.to(device))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "HdGdoGAs84hS",
+ "metadata": {
+ "id": "HdGdoGAs84hS"
+ },
+ "source": [
+ "Note that any arguments passed to the generate method will **supersede** those in the generation config, so setting `do_sample=False` in the call to generate will supersede the setting of `model.generation_config.do_sample` in the generation config."
+ ]
+ },
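+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "generate-override-example",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# illustrative example: this call uses greedy decoding regardless of the config default\n",
+ "greedy_audio_values = model.generate(**inputs.to(device), do_sample=False)"
+ ]
+ },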
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "s__neSDH89q0",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "s__neSDH89q0",
+ "outputId": "8ec39d3c-ac4e-4cce-ee89-1e936a453859"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 1, 642560])"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "audio_values.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "EOMbP5f-imWD",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 35
+ },
+ "id": "EOMbP5f-imWD",
+ "outputId": "df158b75-bcc9-4304-db5c-ddb0d1490c49"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ },
+ "text/plain": [
+ "'import os\\nimport torchaudio\\nimport numpy as np\\nimport torch\\nfrom tqdm import tqdm'"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\"\"\"import os\n",
+ "import torchaudio\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from tqdm import tqdm\"\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "knVkOKRxhMAJ",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 192
+ },
+ "id": "knVkOKRxhMAJ",
+ "outputId": "6b6d0063-70bf-4bd8-822a-5c2f2b6363dc"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ },
+ "text/plain": [
+ "'from datasets import load_dataset, Audio\\nfrom transformers import EncodecModel, AutoProcessor\\n\\n\\n# load a demonstration datasets\\nlibrispeech_dummy = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\\n\\n# load the model + processor (for pre-processing the audio)\\nmodel = EncodecModel.from_pretrained(\"facebook/encodec_48khz\")\\nprocessor = AutoProcessor.from_pretrained(\"facebook/encodec_48khz\")\\n\\npath = \\'/content/Stim_Test_Run01_01_rock.wav\\'\\naudio_loaded, sr = torchaudio.load(path)\\nprint(audio_loaded.shape)\\naudio_loaded = torchaudio.transforms.Resample(sr, 24000)(audio_loaded)\\n#audio_sample = processor(raw_audio=audio_loaded[0], sampling_rate=32000, return_tensors=\"pt\")\\nprint(audio_loaded.shape)\\n# cast the audio data to the correct sampling rate for the model\\n#librispeech_dummy = librispeech_dummy.cast_column(\"audio\", Audio(sampling_rate=processor.sampling_rate))\\n#audio_sample = librispeech_dummy[0][\"audio\"][\"array\"]\\n\\n# pre-process the inputs\\nprint(torch.cat([audio_loaded, audio_loaded], dim = 0).shape)\\ninputs = processor(raw_audio=torch.cat([audio_loaded, audio_loaded], dim = 0).unsqueeze(dim = -1).transpose(0,2), sampling_rate=processor.sampling_rate, return_tensors=\"pt\")\\n\\n# explicitly encode then decode the audio inputs\\nencoder_outputs = model.encode(inputs[\"input_values\"], inputs[\"padding_mask\"])\\naudio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs[\"padding_mask\"])[0]\\n\\n# or the equivalent with a forward pass\\naudio_values = model(inputs[\"input_values\"], inputs[\"padding_mask\"]).audio_values\\n'"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "\"\"\"from datasets import load_dataset, Audio\n",
+ "from transformers import EncodecModel, AutoProcessor\n",
+ "\n",
+ "\n",
+ "# load a demonstration datasets\n",
+ "librispeech_dummy = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
+ "\n",
+ "# load the model + processor (for pre-processing the audio)\n",
+ "model = EncodecModel.from_pretrained(\"facebook/encodec_48khz\")\n",
+ "processor = AutoProcessor.from_pretrained(\"facebook/encodec_48khz\")\n",
+ "\n",
+ "path = '/content/Stim_Test_Run01_01_rock.wav'\n",
+ "audio_loaded, sr = torchaudio.load(path)\n",
+ "print(audio_loaded.shape)\n",
+ "audio_loaded = torchaudio.transforms.Resample(sr, 24000)(audio_loaded)\n",
+ "#audio_sample = processor(raw_audio=audio_loaded[0], sampling_rate=32000, return_tensors=\"pt\")\n",
+ "print(audio_loaded.shape)\n",
+ "# cast the audio data to the correct sampling rate for the model\n",
+ "#librispeech_dummy = librispeech_dummy.cast_column(\"audio\", Audio(sampling_rate=processor.sampling_rate))\n",
+ "#audio_sample = librispeech_dummy[0][\"audio\"][\"array\"]\n",
+ "\n",
+ "# pre-process the inputs\n",
+ "print(torch.cat([audio_loaded, audio_loaded], dim = 0).shape)\n",
+ "inputs = processor(raw_audio=torch.cat([audio_loaded, audio_loaded], dim = 0).unsqueeze(dim = -1).transpose(0,2), sampling_rate=processor.sampling_rate, return_tensors=\"pt\")\n",
+ "\n",
+ "# explicitly encode then decode the audio inputs\n",
+ "encoder_outputs = model.encode(inputs[\"input_values\"], inputs[\"padding_mask\"])\n",
+ "audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs[\"padding_mask\"])[0]\n",
+ "\n",
+ "# or the equivalent with a forward pass\n",
+ "audio_values = model(inputs[\"input_values\"], inputs[\"padding_mask\"]).audio_values\n",
+ "\"\"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "FZ-JP_DFsD_6",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "FZ-JP_DFsD_6",
+ "outputId": "e462f0f5-8267-449f-e8b3-c11a0a7dd593"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 1, 642560])"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "audio_values.shape"
+ ]
+ },
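+ {
+ "cell_type": "markdown",
+ "id": "encodec-roundtrip-note",
+ "metadata": {
+ "id": "encodec-roundtrip-note"
+ },
+ "source": [
+ "The decoded `audio_values` come back as a 3-d tensor, `(batch, channels, num_samples)`, as the shape check above shows. As a quick sanity check on the EnCodec round trip, we can listen to the reconstruction directly in the notebook. This is a minimal sketch, not part of the original pipeline: it assumes the (currently commented-out) EnCodec cell above has been run, so that `audio_values` and `processor` are defined."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "encodec-roundtrip-play",
+ "metadata": {
+ "id": "encodec-roundtrip-play"
+ },
+ "outputs": [],
+ "source": [
+ "from IPython.display import Audio\n",
+ "\n",
+ "# take the first batch element: shape (channels, num_samples)\n",
+ "reconstruction = audio_values[0].detach().cpu().numpy()\n",
+ "\n",
+ "# play the reconstructed waveform at the model's sampling rate\n",
+ "Audio(reconstruction, rate=processor.sampling_rate)"
+ ]
+ },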
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "Wv4xZ1gYlU6S",
+ "metadata": {
+ "id": "Wv4xZ1gYlU6S"
+ },
+ "outputs": [],
+ "source": [
+ "encoder_outputs.audio_codes.shape"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}