diff --git "a/src/Copy_of_MusicGen.ipynb" "b/src/Copy_of_MusicGen.ipynb" new file mode 100644--- /dev/null +++ "b/src/Copy_of_MusicGen.ipynb" @@ -0,0 +1,2880 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "70300319-d206-43ce-b3bf-3da6b079f20f", + "metadata": { + "id": "70300319-d206-43ce-b3bf-3da6b079f20f" + }, + "source": [ + "## MusicGen in 🤗 Transformers\n", + "\n", + "**by [Sanchit Gandhi](https://huggingface.co/sanchit-gandhi)**\n", + "\n", + "MusicGen is a Transformer-based model capable of generating high-quality music samples conditioned on text descriptions or audio prompts. It was proposed in the paper [Simple and Controllable Music Generation](https://arxiv.org/abs/2306.05284) by Jade Copet et al. from Meta AI.\n", + "\n", + "The MusicGen model can be decomposed into three distinct stages:\n", + "1. The text descriptions are passed through a frozen text encoder model to obtain a sequence of hidden-state representations\n", + "2. The MusicGen decoder is then trained to predict discrete audio tokens, or *audio codes*, conditioned on these hidden-states\n", + "3. These audio tokens are then decoded using an audio compression model, such as EnCodec, to recover the audio waveform\n", + "\n", + "The pre-trained MusicGen checkpoints use Google's [t5-base](https://huggingface.co/t5-base) as the text encoder model, and [EnCodec 32kHz](https://huggingface.co/facebook/encodec_32khz) as the audio compression model. The MusicGen decoder is a pure language model architecture,\n", + "trained from scratch on the task of music generation.\n", + "\n", + "The novelty in the MusicGen model is how the audio codes are predicted. Traditionally, each codebook has to be predicted by a separate model (i.e. hierarchically) or by continuously refining the output of the Transformer model (i.e. upsampling). MusicGen uses an efficient *token interleaving pattern*, thus eliminating the need to cascade multiple models to predict a set of codebooks. \n", + "
Instead, it is able to generate the full set of codebooks in a single forward pass of the decoder, resulting in much faster inference.\n", + "\n", + "
\n", + "\n", + "\n", + "**Figure 1:** Codebook delay pattern used by MusicGen. Figure taken from the [MusicGen paper](https://arxiv.org/abs/2306.05284).\n" + ] + }, + { + "cell_type": "markdown", + "id": "e70e6dbb-3211-4ef9-93f6-efaba764ac77", + "metadata": { + "id": "e70e6dbb-3211-4ef9-93f6-efaba764ac77" + }, + "source": [ + "## Prepare the Environment" + ] + }, + { + "cell_type": "markdown", + "id": "04d1fb09-4e19-4e82-a4fa-eea7b20bb96c", + "metadata": { + "id": "04d1fb09-4e19-4e82-a4fa-eea7b20bb96c" + }, + "source": [ + "Let’s make sure we’re connected to a GPU to run this notebook. To get a GPU, click `Runtime` -> `Change runtime type`, then change `Hardware accelerator` from `None` to `GPU`. We can verify that we’ve been assigned a GPU and view its specifications through the `nvidia-smi` command:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "21d38c22-bb79-495c-8aa9-09ceabb2957a", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "21d38c22-bb79-495c-8aa9-09ceabb2957a", + "outputId": "bcb3fd96-cc9c-45fc-a0ce-2fc79398cedc", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Thu Sep 7 22:36:00 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n", + "| N/A 43C P0 55W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 1 NVIDIA A100-SXM... 
On | 00000000:10:1D.0 Off | 0 |\n", + "| N/A 41C P0 52W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 2 NVIDIA A100-SXM... On | 00000000:20:1C.0 Off | 0 |\n", + "| N/A 45C P0 54W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n", + "| N/A 43C P0 60W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n", + "| N/A 44C P0 57W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n", + "| N/A 43C P0 55W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n", + "| N/A 45C P0 54W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 7 NVIDIA A100-SXM... 
On | 00000000:A0:1D.0 Off | 0 |\n", + "| N/A 43C P0 58W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| No running processes found |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "markdown", + "id": "1abcac4f-06b0-41c7-b7e4-960ddd297afd", + "metadata": { + "id": "1abcac4f-06b0-41c7-b7e4-960ddd297afd" + }, + "source": [ + "We see here that we've got eight NVIDIA A100 40GB GPUs, although this may vary for you depending on GPU availability and Colab GPU assignment.\n", + "\n", + "Next, we install the 🤗 Transformers package from the main branch, as well as the 🤗 Datasets package to load audio files for audio-prompted generation:" + ] + }, + { + "cell_type": "markdown", + "id": "77ee39cc-654b-4f0e-b601-013e484c16f0", + "metadata": { + "id": "77ee39cc-654b-4f0e-b601-013e484c16f0" + }, + "source": [ + "## Load the Model\n", + "\n", + "The pre-trained MusicGen small, medium and large checkpoints can be loaded from the [pre-trained weights](https://huggingface.co/models?search=facebook/musicgen-) on the Hugging Face Hub. Change the repo id to the checkpoint size you wish to load. \n", + "
We'll default to the small checkpoint, which is the fastest of the three but has the lowest audio quality:" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "b0d87424-9f38-4658-ba47-2a465d52ad77", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 113, + "referenced_widgets": [ + "013191f7a16c49fdbfc9b6cb3b0aa089", + "1c58ed64d7144bcf9c58fe0f89364d61", + "ce32b8aaae2f4cd38d5ec78fefaa34ce", + "7ed4f572d3534679a3e1e4d90880bd71", + "fd741db4538f493588106d753a747593", + "0aa3f4c09c854d90b9357d356e8ed46b", + "bbf7a9706dd64f6eb24d4b46ff52bc23", + "7f52d5efbb064170ac8d8681ae92f29b", + "881fd640878c420cbc14dfd5f8516953", + "f74afcde56c64ab389fed5bf7f5964b8", + "09bbff09eaa44cdf92612c7f02f05f63", + "0cd8b835ee724b659c92fbee8eb62327", + "50323b27b7ff4823b733bffa97930218", + "abcf0eeed63f4331bbc9a044b2d3d65f", + "2fa893516c3946729fa0eda21550f658", + "7742549b504c4e71b4d0a2119d459936", + "6081dfc2472b4097b01bc7608c100ca3", + "a810f25c280d48539d7382064588efd7", + "b93f98b1284b4cf8bc325e9c4d9a2bf1", + "a43852c0c4754235a7cf3fb7221eed1d", + "410cc9979cc44f5aad48d17cf6042165", + "768bf7c8a965476c99693c9aa3ea89cb", + "f938b2a799d5454f8885d5d6bba24b94", + "a2a0abc232c04d8c96704d6b012227aa", + "4d191abe24514b6a8a5bcb7aa4dd624c", + "9350625ba0fe43949ea335a27f8e402d", + "e9895aee370a4d888cefa7a82cd90c00", + "dbcc2550fed34fc4b9d1f9b5e7465b85", + "aca7c88abc684793880a4619257522c4", + "f83972e889134265aff6737844971e6f", + "9fa6a2c3f81e4a159cf652d8a48acd37", + "92804f5fa9ef49a094ce3d54b051ff9c", + "9ac4c059958446b1af03da2fe2e0ac20" + ] + }, + "id": "b0d87424-9f38-4658-ba47-2a465d52ad77", + "outputId": "2af6a8e9-ff4c-4f2d-8aa2-89f35f0575e7", + "tags": [] + }, + "outputs": [], + "source": [ + "from transformers import MusicgenForConditionalGeneration\n", + "cache_dir = '/fsx/proj-fmri/ckadirt/b2m/cache/'\n", + "model = MusicgenForConditionalGeneration.from_pretrained(\"facebook/musicgen-small\", cache_dir=cache_dir)" + ] + }, + { + 
"cell_type": "markdown", + "id": "4981d112-407c-4120-86aa-5c6a170543f7", + "metadata": { + "id": "4981d112-407c-4120-86aa-5c6a170543f7" + }, + "source": [ + "We can then place the model on our accelerator device (if available), or leave it on the CPU otherwise:" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "9508dee8-39df-46fe-82f3-6cc2e9f21a97", + "metadata": { + "id": "9508dee8-39df-46fe-82f3-6cc2e9f21a97", + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", + "model.to(device);" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ba5cdeee-27d0-4834-b7dc-4403864550f7", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Thu Sep 7 22:37:19 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 NVIDIA A100-SXM... On | 00000000:10:1C.0 Off | 0 |\n", + "| N/A 45C P0 78W / 400W | 3593MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 1 NVIDIA A100-SXM... On | 00000000:10:1D.0 Off | 0 |\n", + "| N/A 41C P0 52W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 2 NVIDIA A100-SXM... 
On | 00000000:20:1C.0 Off | 0 |\n", + "| N/A 45C P0 54W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 3 NVIDIA A100-SXM... On | 00000000:20:1D.0 Off | 0 |\n", + "| N/A 43C P0 60W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 4 NVIDIA A100-SXM... On | 00000000:90:1C.0 Off | 0 |\n", + "| N/A 44C P0 57W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 5 NVIDIA A100-SXM... On | 00000000:90:1D.0 Off | 0 |\n", + "| N/A 43C P0 55W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 6 NVIDIA A100-SXM... On | 00000000:A0:1C.0 Off | 0 |\n", + "| N/A 45C P0 54W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + "| 7 NVIDIA A100-SXM... 
On | 00000000:A0:1D.0 Off | 0 |\n", + "| N/A 43C P0 58W / 400W | 3MiB / 40960MiB | 0% Default |\n", + "| | | Disabled |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| 0 N/A N/A 742410 C ...3/envs/mindeye/bin/python 3590MiB |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "markdown", + "id": "f6e1166e-1335-4555-9ec4-223d1fbcb547", + "metadata": { + "id": "f6e1166e-1335-4555-9ec4-223d1fbcb547" + }, + "source": [ + "## Generation\n", + "\n", + "MusicGen is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly\n", + "better results than greedy, thus we encourage sampling mode to be used where possible. Sampling is enabled by default,\n", + "and can be explicitly specified by setting `do_sample=True` in the call to `MusicgenForConditionalGeneration.generate` (see below).\n", + "\n", + "### Unconditional Generation\n", + "\n", + "The inputs for unconditional (or 'null') generation can be obtained through the method `MusicgenForConditionalGeneration.get_unconditional_inputs`. 
We can then run auto-regressive generation using the `.generate` method, specifying `do_sample=True` to enable sampling mode:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "fb7708e8-e4f1-4ab8-b04a-19395d78dea2", + "metadata": { + "id": "fb7708e8-e4f1-4ab8-b04a-19395d78dea2", + "tags": [] + }, + "outputs": [], + "source": [ + "unconditional_inputs = model.get_unconditional_inputs(num_samples=1)\n", + "\n", + "audio_values = model.generate(**unconditional_inputs, do_sample=True, max_new_tokens=256)" + ] + }, + { + "cell_type": "markdown", + "id": "94cb74df-c194-4d2e-930a-12473b08a919", + "metadata": { + "id": "94cb74df-c194-4d2e-930a-12473b08a919" + }, + "source": [ + "The audio outputs are a three-dimensional Torch tensor of shape `(batch_size, num_channels, sequence_length)`. To listen\n", + "to the generated audio samples, you can either play them in an ipynb notebook:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "15f0bc7c-b899-4e7a-943e-594e73f080ea", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 75 + }, + "id": "15f0bc7c-b899-4e7a-943e-594e73f080ea", + "outputId": "9a47adc3-17b1-40d6-989d-73319c9ea7ee", + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from IPython.display import Audio\n", + "\n", + "sampling_rate = model.config.audio_encoder.sampling_rate\n", + "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)" + ] + }, + { + "cell_type": "markdown", + "id": "6de58334-40f7-4924-addb-2d6ff34c0590", + "metadata": { + "id": "6de58334-40f7-4924-addb-2d6ff34c0590" + }, + "source": [ + "Or save them as a `.wav` file using a third-party library, e.g. 
`scipy` (note here that we also need to remove the channel dimension from our audio tensor):" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "04291f52-0a75-4ddb-9eff-e853d0f17288", + "metadata": { + "id": "04291f52-0a75-4ddb-9eff-e853d0f17288" + }, + "outputs": [], + "source": [ + "import scipy\n", + "\n", + "scipy.io.wavfile.write(\"musicgen_out.wav\", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())" + ] + }, + { + "cell_type": "markdown", + "id": "e52ff5b2-c170-4079-93a4-a02acbdaeb39", + "metadata": { + "id": "e52ff5b2-c170-4079-93a4-a02acbdaeb39" + }, + "source": [ + "The argument `max_new_tokens` specifies the number of new tokens to generate. As a rule of thumb, you can work out the length of the generated audio sample in seconds by using the frame rate of the EnCodec model:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d75ad107-e19b-47f3-9cf1-5102ab4ae74a", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "d75ad107-e19b-47f3-9cf1-5102ab4ae74a", + "outputId": "c021297b-837b-4d35-e01b-34a66a33c1dd", + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "5.12" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "audio_length_in_s = 256 / model.config.audio_encoder.frame_rate\n", + "\n", + "audio_length_in_s" + ] + }, + { + "cell_type": "markdown", + "id": "9a0e999b-2595-4090-8e1a-acfaa42d2581", + "metadata": { + "id": "9a0e999b-2595-4090-8e1a-acfaa42d2581" + }, + "source": [ + "### Text-Conditional Generation\n", + "\n", + "The model can generate an audio sample conditioned on a text prompt through use of the `MusicgenProcessor` to pre-process\n", + "the inputs. 
The pre-processed inputs can then be passed to the `.generate` method to generate text-conditional audio samples.\n", + "Again, we enable sampling mode by setting `do_sample=True`:" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "5fba4154-13f6-403a-958b-101d6eacfb6e", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 75 + }, + "id": "5fba4154-13f6-403a-958b-101d6eacfb6e", + "outputId": "a4b87d4b-1db0-49af-dac0-eb600c47cbfb", + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from transformers import AutoProcessor\n", + "from einops import rearrange\n", + "\n", + "processor = AutoProcessor.from_pretrained(\"facebook/musicgen-small\")\n", + "\n", + "inputs = processor(\n", + " text=[\"hiphop beat\", \"90s rock song with loud guitars and heavy drums\"],\n", + " padding=True,\n", + " return_tensors=\"pt\",\n", + ")\n", + "\n", + "audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256)\n", + "\n", + "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "_FZb_Zo-Dajl", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_FZb_Zo-Dajl", + "outputId": "9834a745-ac10-4142-e22f-6998997304e5", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 2, 4, 253])\n" + ] + } + ], + "source": [ + "with torch.no_grad():\n", + " tokens_train = model.audio_encoder(audio_values)\n", + " print(tokens_train.audio_codes.shape)\n", + " tokens_on_format = rearrange(tokens_train.audio_codes, \"n b c l -> (n b c) l\")\n", + " # make a copy of the tokens detached from the computation graph\n", + " tokens_on_format = tokens_on_format.detach().clone()" + ] + },
+ { + "cell_type": "code", + "execution_count": 61, + "id": "VHyalHz78TlY", + "metadata": { + "id": "VHyalHz78TlY", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([2, 13, 768])\n" + ] + } + ], + "source": [ + "with torch.no_grad():\n", + " encoder_hidden_states = model.text_encoder(inputs.input_ids, attention_mask = inputs.attention_mask.detach().clone()).last_hidden_state.detach().clone()\n", + " print(encoder_hidden_states.shape)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "dOnCaS3F9yfz", + "metadata": { + "id": "dOnCaS3F9yfz", + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([2, 13, 1024])\n" + ] + } + ], + "source": [ + "with torch.no_grad():\n", + " encoder_hidden_states = model.enc_to_dec_proj(encoder_hidden_states).detach().clone()\n", + " print(encoder_hidden_states.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "z56a5Yjqoum5", + "metadata": { + "id": "z56a5Yjqoum5", + "tags": [] + }, + "outputs": [], + "source": [ + "pad_token_id = model.generation_config.pad_token_id\n", + "decoder_input_ids = (\n", + " torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long)\n", + " * pad_token_id\n", + ").to(device).detach().clone()\n", + "\n", + "inputs.attention_mask = inputs.attention_mask.detach().clone()" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "id": "xLrDtoYhqBiP", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xLrDtoYhqBiP", + "outputId": "78909eb6-af29-4655-d473-86359093c9d7", + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([8, 1])" + ] + }, + "execution_count": 67, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "decoder_input_ids.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "id": 
"b67b9b42-fd04-4b0c-ac36-a5b7016ada16", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([8, 253])" + ] + }, + "execution_count": 107, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokens_on_format.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "hlKZ7dJgH1VJ", + "metadata": { + "id": "hlKZ7dJgH1VJ", + "tags": [] + }, + "outputs": [], + "source": [ + "results_manual = model.decoder(encoder_hidden_states = encoder_hidden_states * inputs.attention_mask[..., None], input_ids = tokens_on_format[:,:-1], encoder_attention_mask = inputs.attention_mask)" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "id": "62b31feb-e392-4b5c-9855-25d97364256b", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([8, 252, 2048])" + ] + }, + "execution_count": 72, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results_manual.logits.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "id": "a5fa2fd1", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([2008, 2048]) torch.Size([2008])\n" + ] + } + ], + "source": [ + "def training_step(labels, results_manual):\n", + " # as the labels are the same as the decoder inputs, we remove the first column\n", + " labels = labels[:, 1:]\n", + " # and we remove the last column of the logits, so logits[:, t] is scored against labels[:, t + 1]\n", + " results_manual.logits = results_manual.logits[:, :-1, :]\n", + " loss_fct = torch.nn.CrossEntropyLoss()\n", + " print(rearrange(results_manual.logits, \"c t v -> (c t) v\").shape, rearrange(labels, \"c t -> (c t)\").shape)\n", + " loss = loss_fct(rearrange(results_manual.logits, \"c t v -> (c t) v\"), rearrange(labels, \"c t -> (c t)\"))\n", + " return loss\n", + "\n", + "loss = training_step(tokens_on_format[:,:-1], results_manual)" + ] + },
+ { + "cell_type": "code", + "execution_count": 74, + "id": "44f38b39-c06b-4e50-b233-ef6b6f829de5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(10.8952, device='cuda:0', grad_fn=)" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "4022c586-423e-489e-9647-c0c24a378092", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "model.zero_grad()" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "06335817-59a3-47b2-a2df-10d71acdc3c4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "tokens_on_format_test_tt = tokens_on_format_test.detach()" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "id": "82367433", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.3841, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(6.5876, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(5.2953, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(3.9745, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(3.5427, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(3.1808, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.9803, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.8768, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.7334, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.6114, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 
2048]) torch.Size([2008])\n", + "tensor(2.5443, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.4862, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.4340, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.3890, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.3461, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.3014, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.2524, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.2108, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.1693, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.1228, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.0805, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.0398, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(2.0003, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.9617, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.9265, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.8898, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.8527, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.8144, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.7804, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.7465, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 
2048]) torch.Size([2008])\n", + "tensor(1.7152, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.6817, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.6494, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.6182, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.5891, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.5631, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.5393, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.5514, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.5996, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.5329, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.5090, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.4846, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.4633, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.4849, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.3862, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.4540, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.3665, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.3937, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.3255, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.3280, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 
2048]) torch.Size([2008])\n", + "tensor(1.2869, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.2772, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.2436, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.2343, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.2054, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.1852, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.1706, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.1331, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.1386, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.1027, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.0897, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.0732, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.1328, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.4694, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.3952, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.2203, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.2310, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.1787, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.2025, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.1772, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 
2048]) torch.Size([2008])\n", + "tensor(1.1184, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.0861, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.0596, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.0484, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.0140, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.0278, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(0.9599, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(0.9769, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(0.9378, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(0.9275, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(0.9106, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(0.8895, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(0.8961, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(0.9111, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.1137, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.5037, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.5237, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.3981, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.1750, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.3016, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 
2048]) torch.Size([2008])\n", + "tensor(1.1642, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.1759, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.0935, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.0816, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.0367, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(1.0183, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(0.9768, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(0.9370, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(0.9176, device='cuda:0', grad_fn=)\n", + "torch.Size([2008, 2048]) torch.Size([2008])\n", + "tensor(0.8828, device='cuda:0', grad_fn=)\n" + ] + } + ], + "source": [ + "# optionally move everything to the CPU before training\n", + "#encoder_hidden_states = encoder_hidden_states.cpu()\n", + "#inputs.attention_mask = inputs.attention_mask.cpu()\n", + "#decoder_attention_mask = decoder_attention_mask.cpu()\n", + "#tokens_on_format_test = tokens_on_format_test.cpu()\n", + "#results_manual.logits = results_manual.logits.cpu()\n", + "#model = model.cpu()\n", + "\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)\n", + "for i in range(100):\n", + " optimizer.zero_grad()\n", + " # teacher-forced pass through the MusicGen decoder, with padded encoder states zeroed out\n", + " results_manual = model.decoder(encoder_hidden_states=encoder_hidden_states * inputs.attention_mask[..., None], input_ids=tokens_on_format[:, :-1], encoder_attention_mask=inputs.attention_mask)\n", + " # loss between the decoder predictions and the target audio codes (training_step defined earlier)\n", + " loss = training_step(tokens_on_format[:, :-1].detach().clone(), results_manual)\n", + " loss.backward()\n", + " optimizer.step()\n", + " print(loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "id": 
"c9870eac-70ca-4ff5-a4cf-2138bef89de1", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 578, 68, 244, 753, 135, 135, 40, 244, 1358, 83, 1265, 91,\n", + " 555, 816, 236, 235, 235, 244, 135, 135, 609, 244, 609, 135,\n", + " 1081, 480, 549, 14, 327, 613, 244, 244, 83, 135, 66, 244,\n", + " 472, 403, 472, 68, 1757, 753, 1348, 289, 609, 289, 609, 289,\n", + " 609, 539, 578, 472, 83, 609, 66, 609, 66, 609, 289, 609,\n", + " 289, 289, 403, 570, 23, 116, 14, 376, 609, 1757, 135, 1358,\n", + " 83, 40, 1205, 578, 444, 68, 135, 1358, 289, 83, 289, 1265,\n", + " 1757, 289, 289, 1081, 1036, 68, 1453, 1265, 1265, 1265, 289, 8,\n", + " 376, 376, 289, 778, 403, 444, 68, 1265, 1757, 1453, 376, 376,\n", + " 376, 289, 609, 1265, 539, 434, 334, 1986, 83, 244, 235, 244,\n", + " 1265, 609, 244, 1265, 1453, 172, 1036, 68, 83, 1453, 1453, 609,\n", + " 1453, 1265, 1265, 1265, 1265, 1205, 555, 588, 116, 235, 244, 244,\n", + " 172, 275, 68, 376, 753, 135, 376, 986, 1348, 609, 8, 609,\n", + " 609, 609, 609, 289, 609, 609, 778, 403, 1360, 68, 1265, 21,\n", + " 1358, 1453, 40, 1453, 289, 1453, 289, 91, 578, 226, 1265, 1265,\n", + " 376, 1265, 289, 1453, 289, 244, 135, 244, 403, 1036, 425, 609,\n", + " 244, 1043, 244, 1348, 135, 1348, 244, 1043, 1488, 403, 425, 40,\n", + " 83, 244, 254, 40, 135, 1265, 135, 40, 244, 979, 759, 834,\n", + " 327, 135, 289, 235, 235, 1358, 327, 289, 244, 778, 403, 1953,\n", + " 1617, 244, 83, 14, 135, 1358, 244, 40, 135, 83, 539, 578,\n", + " 68, 83, 83, 1453, 1265, 135, 40, 289, 1453, 609, 1453, 403,\n", + " 275], device='cuda:0')" + ] + }, + "execution_count": 98, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokens_on_format[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "id": "61b8b0e2-8185-4801-9ea3-8b05820ed062", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "results_manual = model.decoder(encoder_hidden_states = encoder_hidden_states * 
inputs.attention_mask[..., None] , input_ids = tokens_on_format[:,:-1], encoder_attention_mask = inputs.attention_mask)" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "id": "2a1cfaa7-707c-4ed1-8a74-14d595a2d7bf", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([ 135, 244, 235, 609, 1265, 40, 1358, 68, 83, 816, 236, 555,\n", + " 549, 753, 91, 1081, 480, 14, 327, 613, 116, 23, 1757, 1379,\n", + " 21, 289, 588, 10, 376, 2007, 148, 146, 789, 1602, 425, 429,\n", + " 570, 1132, 172, 1522, 403, 1488, 1846, 1043, 8, 66, 902, 254,\n", + " 539, 1106, 472, 374, 1453, 1348, 1671, 226, 778, 1195, 1360, 275,\n", + " 1334, 759, 1875, 949, 993, 1194, 444, 1796, 97, 1676, 1187, 1304,\n", + " 1953, 492, 1698, 578, 979, 1036, 1775, 434, 1210, 1617, 15, 149,\n", + " 276, 1487, 1540, 986, 1398, 721, 243, 4, 1748, 131, 1986, 1544,\n", + " 223, 227, 1109, 1815], device='cuda:0')\n" + ] + } + ], + "source": [ + "max_idx = torch.topk(results_manual.logits[0][9], k=100)[1]\n", + "print(max_idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "UV1XbuCRONUA", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "UV1XbuCRONUA", + "outputId": "0b846e0f-180d-4bd6-946f-fb29665479d1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1, 2048])\n", + "tensor([-0.1700, -3.6208, -0.9766, -1.1846, -1.3526, 1.4435, 2.6102, -2.6462,\n", + " -1.3472, -1.6042], device='cuda:0', grad_fn=)\n" + ] + } + ], + "source": [ + "print(results_manual.logits.shape)\n", + "#results_manual.past_key_values.shape\n", + "print(results_manual.logits[1][0][0:10])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "YpFaaCefZ4mp", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YpFaaCefZ4mp", + "outputId": "c49266f8-4602-47ae-94c9-de472661ba5b" + }, + "outputs": [ + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "2048\n" + ] + } + ], + "source": [ + "inputs = processor(\n", + " text=[\"michael jackson singing pop\", \"90s rock song with loud guitars and heavy drums\"],\n", + " padding=True,\n", + " return_tensors=\"pt\",\n", + ").to('cuda')\n", + "\n", + "pad_token_id = model.generation_config.pad_token_id\n", + "print(pad_token_id)\n", + "decoder_input_ids = (\n", + " torch.ones((inputs.input_ids.shape[0] * model.decoder.num_codebooks, 1), dtype=torch.long)\n", + " * pad_token_id\n", + ").to('cuda')\n", + "\n", + "end = model(**inputs, decoder_input_ids=decoder_input_ids)\n", + "#logits.shape # (bsz * num_codebooks, tgt_len, vocab_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "Fdr_sVdLQF3m", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Fdr_sVdLQF3m", + "outputId": "d4f55f93-e76f-4a90-9253-b6fb51a9add5" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'input_ids': tensor([[2278, 9, 15, 40, 3, 9325, 739, 8097, 2783, 1, 0, 0,\n", + " 0],\n", + " [2777, 7, 2480, 2324, 28, 8002, 5507, 7, 11, 2437, 5253, 7,\n", + " 1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}" + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inputs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "FyTh1KykP1OT", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FyTh1KykP1OT", + "outputId": "036edf58-44b1-40b8-cbad-137ddad5be9b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1, 2048])\n", + "tensor([-0.1700, -3.6208, -0.9766, -1.1846, -1.3526, 1.4435, 2.6102, -2.6462,\n", + " -1.3472, -1.6042], device='cuda:0', grad_fn=)\n" + ] + } + ], + "source": [ + "print(end.logits.shape)\n", + 
"print(end.logits[1][0][0:10])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "h_0zNWeSIXxC", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "h_0zNWeSIXxC", + "outputId": "a7a3fe57-ceca-4e35-b8a6-7c870fcfdb54" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2048],\n", + " [2048],\n", + " [2048],\n", + " [2048],\n", + " [2048],\n", + " [2048],\n", + " [2048],\n", + " [2048]], device='cuda:0')" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(end.logits[1][1][1:10])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "-1y5EZojYzDz", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-1y5EZojYzDz", + "outputId": "5db3d81c-0a18-4386-92c7-b20354e2c902" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "odict_keys(['logits', 'past_key_values', 'encoder_last_hidden_state'])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "end.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "tL-0dkVkbeka", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tL-0dkVkbeka", + "outputId": "2838c153-9344-4ebc-96fb-98331e351e72" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "24" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(end.past_key_values)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2X62iV8iVNRU", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 75 + }, + "id": "2X62iV8iVNRU", + "outputId": "6b97e3ff-f8b7-40a1-9139-d473354b115c" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + 
"source": [ + "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2xlbMjFTTUBd", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2xlbMjFTTUBd", + "outputId": "c0228562-0a8e-4198-9171-027a5d49e8fc" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 2775, 7, 2783, 1463, 28, 7981, 63, 5253, 7, 11,\n", + " 13353, 1, 0],\n", + " [ 2777, 7, 2480, 2324, 28, 8002, 5507, 7, 11, 2437,\n", + " 5253, 7, 1]], device='cuda:0')" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inputs.input_ids" + ] + }, + { + "cell_type": "markdown", + "id": "4851a94c-ae02-41c9-b1dd-c1422ba34dc0", + "metadata": { + "id": "4851a94c-ae02-41c9-b1dd-c1422ba34dc0" + }, + "source": [ + "The `guidance_scale` is used in classifier free guidance (CFG), setting the weighting between the conditional logits\n", + "(which are predicted from the text prompts) and the unconditional logits (which are predicted from an unconditional or\n", + "'null' prompt). A higher guidance scale encourages the model to generate samples that are more closely linked to the input\n", + "prompt, usually at the expense of poorer audio quality. CFG is enabled by setting `guidance_scale > 1`. For best results,\n", + "use a `guidance_scale=3` (default) for text and audio-conditional generation." + ] + }, + { + "cell_type": "markdown", + "id": "d391b2a1-6376-4b69-b562-4388b731cf60", + "metadata": { + "id": "d391b2a1-6376-4b69-b562-4388b731cf60" + }, + "source": [ + "### Audio-Prompted Generation\n", + "\n", + "The same `MusicgenProcessor` can be used to pre-process an audio prompt that is used for audio continuation. 
In the\n", + "following example, we load an audio file using the 🤗 Datasets library, pre-process it using the processor class,\n", + "and then forward the inputs to the model for generation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56a5c28a-f6c1-4ac8-ae08-6776a2b2c5b8", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 75 + }, + "id": "56a5c28a-f6c1-4ac8-ae08-6776a2b2c5b8", + "outputId": "81c95bfd-649d-424f-a764-0f1b881f37dd" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset = load_dataset(\"sanchit-gandhi/gtzan\", split=\"train\", streaming=True)\n", + "sample = next(iter(dataset))[\"audio\"]\n", + "\n", + "# take the first half of the audio sample\n", + "sample[\"array\"] = sample[\"array\"][: len(sample[\"array\"]) // 2]\n", + "\n", + "inputs = processor(\n", + " audio=sample[\"array\"],\n", + " sampling_rate=sample[\"sampling_rate\"],\n", + " text=[\"80s blues track with groovy saxophone\"],\n", + " padding=True,\n", + " return_tensors=\"pt\",\n", + ")\n", + "\n", + "audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256)\n", + "\n", + "Audio(audio_values[0].cpu().numpy(), rate=sampling_rate)" + ] + }, + { + "cell_type": "markdown", + "id": "77518aa4-1b9b-4af6-b5ac-8ecdcb79b4cc", + "metadata": { + "id": "77518aa4-1b9b-4af6-b5ac-8ecdcb79b4cc" + }, + "source": [ + "To demonstrate batched audio-prompted generation, we'll slice our sample audio by two different proportions to give two audio samples of different length.\n", + "Since the input audio prompts vary in length, they will be *padded* to the length of the longest audio sample in the batch before being passed to the model.\n", + "\n", + "To recover the final audio samples, the 
`audio_values` generated can be post-processed to remove padding by using the processor class once again:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5495f568-51ca-439d-b47b-8b52e89b78f1", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 75 + }, + "id": "5495f568-51ca-439d-b47b-8b52e89b78f1", + "outputId": "68866570-000b-4239-bed5-75bc21e828aa" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample = next(iter(dataset))[\"audio\"]\n", + "\n", + "# take the first quarter of the audio sample\n", + "sample_1 = sample[\"array\"][: len(sample[\"array\"]) // 4]\n", + "\n", + "# take the first half of the audio sample\n", + "sample_2 = sample[\"array\"][: len(sample[\"array\"]) // 2]\n", + "\n", + "inputs = processor(\n", + " audio=[sample_1, sample_2],\n", + " sampling_rate=sample[\"sampling_rate\"],\n", + " text=[\"80s blues track with groovy saxophone\", \"90s rock song with loud guitars and heavy drums\"],\n", + " padding=True,\n", + " return_tensors=\"pt\",\n", + ")\n", + "\n", + "audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=256)\n", + "\n", + "# post-process to remove padding from the batched audio\n", + "audio_values = processor.batch_decode(audio_values, padding_mask=inputs.padding_mask)\n", + "\n", + "Audio(audio_values[0], rate=sampling_rate)" + ] + }, + { + "cell_type": "markdown", + "id": "viwTDmzl8ZDN", + "metadata": { + "id": "viwTDmzl8ZDN" + }, + "source": [ + "## Generation Config\n", + "\n", + "The default parameters that control the generation process, such as sampling, guidance scale and number of generated tokens, can be found in the model's generation config, and updated as desired. 
Let's first inspect the default generation config:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0zM4notb8Y1g", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0zM4notb8Y1g", + "outputId": "576755e5-42b9-48bc-be13-76322837fea0" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "GenerationConfig {\n", + " \"_from_model_config\": true,\n", + " \"bos_token_id\": 2048,\n", + " \"decoder_start_token_id\": 2048,\n", + " \"do_sample\": true,\n", + " \"guidance_scale\": 3.0,\n", + " \"max_length\": 1500,\n", + " \"pad_token_id\": 2048,\n", + " \"transformers_version\": \"4.34.0.dev0\"\n", + "}" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.generation_config" + ] + }, + { + "cell_type": "markdown", + "id": "DLSnSwau8jyW", + "metadata": { + "id": "DLSnSwau8jyW" + }, + "source": [ + "Alright! We see that the model defaults to using sampling mode (`do_sample=True`), a guidance scale of 3, and a maximum generation length of 1500 (which is equivalent to 30s of audio). 
You can update any of these attributes to change the default generation parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ensSj1IB81dA", + "metadata": { + "id": "ensSj1IB81dA" + }, + "outputs": [], + "source": [ + "# increase the guidance scale to 4.0\n", + "model.generation_config.guidance_scale = 4.0\n", + "\n", + "# set the max new tokens to 256\n", + "model.generation_config.max_new_tokens = 256\n", + "\n", + "# set the softmax sampling temperature to 1.5\n", + "model.generation_config.temperature = 1.5" + ] + }, + { + "cell_type": "markdown", + "id": "UjqGnfc-9ZFJ", + "metadata": { + "id": "UjqGnfc-9ZFJ" + }, + "source": [ + "Re-running generation now will use the newly defined values in the generation config:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "KAExrhDl9YvS", + "metadata": { + "id": "KAExrhDl9YvS" + }, + "outputs": [], + "source": [ + "audio_values = model.generate(**inputs.to(device))" + ] + }, + { + "cell_type": "markdown", + "id": "HdGdoGAs84hS", + "metadata": { + "id": "HdGdoGAs84hS" + }, + "source": [ + "Note that any arguments passed to the generate method will **supersede** those in the generation config, so setting `do_sample=False` in the call to generate will supersede the setting of `model.generation_config.do_sample` in the generation config." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "s__neSDH89q0", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "s__neSDH89q0", + "outputId": "8ec39d3c-ac4e-4cce-ee89-1e936a453859" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 1, 642560])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "audio_values.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "EOMbP5f-imWD", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "id": "EOMbP5f-imWD", + "outputId": "df158b75-bcc9-4304-db5c-ddb0d1490c49" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "'import os\\nimport torchaudio\\nimport numpy as np\\nimport torch\\nfrom tqdm import tqdm'" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\"\"\"import os\n", + "import torchaudio\n", + "import numpy as np\n", + "import torch\n", + "from tqdm import tqdm\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "knVkOKRxhMAJ", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 192 + }, + "id": "knVkOKRxhMAJ", + "outputId": "6b6d0063-70bf-4bd8-822a-5c2f2b6363dc" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "'from datasets import load_dataset, Audio\\nfrom transformers import EncodecModel, AutoProcessor\\n\\n\\n# load a demonstration datasets\\nlibrispeech_dummy = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\\n\\n# load the model + processor (for pre-processing the audio)\\nmodel = EncodecModel.from_pretrained(\"facebook/encodec_48khz\")\\nprocessor = 
AutoProcessor.from_pretrained(\"facebook/encodec_48khz\")\\n\\npath = \\'/content/Stim_Test_Run01_01_rock.wav\\'\\naudio_loaded, sr = torchaudio.load(path)\\nprint(audio_loaded.shape)\\naudio_loaded = torchaudio.transforms.Resample(sr, 24000)(audio_loaded)\\n#audio_sample = processor(raw_audio=audio_loaded[0], sampling_rate=32000, return_tensors=\"pt\")\\nprint(audio_loaded.shape)\\n# cast the audio data to the correct sampling rate for the model\\n#librispeech_dummy = librispeech_dummy.cast_column(\"audio\", Audio(sampling_rate=processor.sampling_rate))\\n#audio_sample = librispeech_dummy[0][\"audio\"][\"array\"]\\n\\n# pre-process the inputs\\nprint(torch.cat([audio_loaded, audio_loaded], dim = 0).shape)\\ninputs = processor(raw_audio=torch.cat([audio_loaded, audio_loaded], dim = 0).unsqueeze(dim = -1).transpose(0,2), sampling_rate=processor.sampling_rate, return_tensors=\"pt\")\\n\\n# explicitly encode then decode the audio inputs\\nencoder_outputs = model.encode(inputs[\"input_values\"], inputs[\"padding_mask\"])\\naudio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs[\"padding_mask\"])[0]\\n\\n# or the equivalent with a forward pass\\naudio_values = model(inputs[\"input_values\"], inputs[\"padding_mask\"]).audio_values\\n'" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\"\"\"from datasets import load_dataset, Audio\n", + "from transformers import EncodecModel, AutoProcessor\n", + "\n", + "\n", + "# load a demonstration datasets\n", + "librispeech_dummy = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n", + "\n", + "# load the model + processor (for pre-processing the audio)\n", + "model = EncodecModel.from_pretrained(\"facebook/encodec_48khz\")\n", + "processor = AutoProcessor.from_pretrained(\"facebook/encodec_48khz\")\n", + "\n", + "path = '/content/Stim_Test_Run01_01_rock.wav'\n", + "audio_loaded, sr = 
torchaudio.load(path)\n", + "print(audio_loaded.shape)\n", + "audio_loaded = torchaudio.transforms.Resample(sr, 24000)(audio_loaded)\n", + "#audio_sample = processor(raw_audio=audio_loaded[0], sampling_rate=32000, return_tensors=\"pt\")\n", + "print(audio_loaded.shape)\n", + "# cast the audio data to the correct sampling rate for the model\n", + "#librispeech_dummy = librispeech_dummy.cast_column(\"audio\", Audio(sampling_rate=processor.sampling_rate))\n", + "#audio_sample = librispeech_dummy[0][\"audio\"][\"array\"]\n", + "\n", + "# pre-process the inputs\n", + "print(torch.cat([audio_loaded, audio_loaded], dim = 0).shape)\n", + "inputs = processor(raw_audio=torch.cat([audio_loaded, audio_loaded], dim = 0).unsqueeze(dim = -1).transpose(0,2), sampling_rate=processor.sampling_rate, return_tensors=\"pt\")\n", + "\n", + "# explicitly encode then decode the audio inputs\n", + "encoder_outputs = model.encode(inputs[\"input_values\"], inputs[\"padding_mask\"])\n", + "audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs[\"padding_mask\"])[0]\n", + "\n", + "# or the equivalent with a forward pass\n", + "audio_values = model(inputs[\"input_values\"], inputs[\"padding_mask\"]).audio_values\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "FZ-JP_DFsD_6", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FZ-JP_DFsD_6", + "outputId": "e462f0f5-8267-449f-e8b3-c11a0a7dd593" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 1, 642560])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "audio_values.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "F-b6Q4TSjGcc", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 175 + }, + "id": "F-b6Q4TSjGcc", + "outputId": "cfe3e974-ca58-48de-833f-12b1d8e4e1b8" + }, + "outputs": [ + { + "ename": 
"NameError", + "evalue": "ignored", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mencoder_outputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maudio_codes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'encoder_outputs' is not defined" + ] + } + ], + "source": [ + "encoder_outputs.audio_codes.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "Wv4xZ1gYlU6S", + "metadata": { + "id": "Wv4xZ1gYlU6S" + }, + "outputs": [], + "source": [ + "encoder_outputs.audio_codes.shape" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "013191f7a16c49fdbfc9b6cb3b0aa089": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_1c58ed64d7144bcf9c58fe0f89364d61", + 
"IPY_MODEL_ce32b8aaae2f4cd38d5ec78fefaa34ce", + "IPY_MODEL_7ed4f572d3534679a3e1e4d90880bd71" + ], + "layout": "IPY_MODEL_fd741db4538f493588106d753a747593" + } + }, + "09bbff09eaa44cdf92612c7f02f05f63": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "0aa3f4c09c854d90b9357d356e8ed46b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "0cd8b835ee724b659c92fbee8eb62327": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + 
"model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_50323b27b7ff4823b733bffa97930218", + "IPY_MODEL_abcf0eeed63f4331bbc9a044b2d3d65f", + "IPY_MODEL_2fa893516c3946729fa0eda21550f658" + ], + "layout": "IPY_MODEL_7742549b504c4e71b4d0a2119d459936" + } + }, + "1c58ed64d7144bcf9c58fe0f89364d61": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_0aa3f4c09c854d90b9357d356e8ed46b", + "placeholder": "​", + "style": "IPY_MODEL_bbf7a9706dd64f6eb24d4b46ff52bc23", + "value": "Downloading (…)lve/main/config.json: 100%" + } + }, + "2fa893516c3946729fa0eda21550f658": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_410cc9979cc44f5aad48d17cf6042165", + "placeholder": "​", + "style": "IPY_MODEL_768bf7c8a965476c99693c9aa3ea89cb", + "value": " 2.36G/2.36G [00:22<00:00, 235MB/s]" + } + }, + "410cc9979cc44f5aad48d17cf6042165": { + "model_module": 
"@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4d191abe24514b6a8a5bcb7aa4dd624c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f83972e889134265aff6737844971e6f", + "max": 224, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_9fa6a2c3f81e4a159cf652d8a48acd37", + "value": 224 + } + }, + "50323b27b7ff4823b733bffa97930218": { + "model_module": 
"@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_6081dfc2472b4097b01bc7608c100ca3", + "placeholder": "​", + "style": "IPY_MODEL_a810f25c280d48539d7382064588efd7", + "value": "Downloading model.safetensors: 100%" + } + }, + "6081dfc2472b4097b01bc7608c100ca3": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "768bf7c8a965476c99693c9aa3ea89cb": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + 
"model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "7742549b504c4e71b4d0a2119d459936": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7ed4f572d3534679a3e1e4d90880bd71": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + 
"_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f74afcde56c64ab389fed5bf7f5964b8", + "placeholder": "​", + "style": "IPY_MODEL_09bbff09eaa44cdf92612c7f02f05f63", + "value": " 7.87k/7.87k [00:00<00:00, 288kB/s]" + } + }, + "7f52d5efbb064170ac8d8681ae92f29b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "881fd640878c420cbc14dfd5f8516953": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, 
+ "92804f5fa9ef49a094ce3d54b051ff9c": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9350625ba0fe43949ea335a27f8e402d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_92804f5fa9ef49a094ce3d54b051ff9c", + "placeholder": "​", + "style": "IPY_MODEL_9ac4c059958446b1af03da2fe2e0ac20", + "value": " 224/224 [00:00<00:00, 13.9kB/s]" + } + }, + "9ac4c059958446b1af03da2fe2e0ac20": { + "model_module": 
"@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "9fa6a2c3f81e4a159cf652d8a48acd37": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "a2a0abc232c04d8c96704d6b012227aa": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_dbcc2550fed34fc4b9d1f9b5e7465b85", + "placeholder": "​", + "style": "IPY_MODEL_aca7c88abc684793880a4619257522c4", + "value": "Downloading (…)neration_config.json: 100%" + } + }, + "a43852c0c4754235a7cf3fb7221eed1d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": 
"StyleView", + "bar_color": null, + "description_width": "" + } + }, + "a810f25c280d48539d7382064588efd7": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "abcf0eeed63f4331bbc9a044b2d3d65f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b93f98b1284b4cf8bc325e9c4d9a2bf1", + "max": 2364427288, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_a43852c0c4754235a7cf3fb7221eed1d", + "value": 2364427288 + } + }, + "aca7c88abc684793880a4619257522c4": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b93f98b1284b4cf8bc325e9c4d9a2bf1": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + 
"_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bbf7a9706dd64f6eb24d4b46ff52bc23": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "ce32b8aaae2f4cd38d5ec78fefaa34ce": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, 
+ "layout": "IPY_MODEL_7f52d5efbb064170ac8d8681ae92f29b", + "max": 7866, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_881fd640878c420cbc14dfd5f8516953", + "value": 7866 + } + }, + "dbcc2550fed34fc4b9d1f9b5e7465b85": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e9895aee370a4d888cefa7a82cd90c00": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + 
"flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f74afcde56c64ab389fed5bf7f5964b8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f83972e889134265aff6737844971e6f": { + "model_module": 
"@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f938b2a799d5454f8885d5d6bba24b94": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_a2a0abc232c04d8c96704d6b012227aa", + "IPY_MODEL_4d191abe24514b6a8a5bcb7aa4dd624c", + "IPY_MODEL_9350625ba0fe43949ea335a27f8e402d" + ], + "layout": "IPY_MODEL_e9895aee370a4d888cefa7a82cd90c00" + } + }, + "fd741db4538f493588106d753a747593": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + 
"model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}