ArseniyPerchik committed on
Commit
25a1345
·
1 Parent(s): a853d77
{good_policies → agent_policies}/sac_warehouse_r_10_working_v1.zip RENAMED
File without changes
{good_policies → agent_policies}/sac_warehouse_r_20.zip RENAMED
File without changes
app.py CHANGED
@@ -1,62 +1,198 @@
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import matplotlib.pyplot as plt
4
  import matplotlib.animation as animation
5
  import tempfile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
- def create_animation():
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  fig, ax = plt.subplots(figsize=(7, 7))
10
- xdata, ydata = [], []
11
- ln, = plt.plot([], [], 'b-', animated=True)
 
 
 
 
 
 
 
 
12
 
13
  def init():
14
- ax.set_xlim(0, 2*np.pi)
15
- ax.set_ylim(-1.1, 1.1)
16
- return ln,
 
17
 
18
  def update(frame):
19
- xdata.append(frame)
20
- ydata.append(np.sin(frame))
21
- ln.set_data(xdata, ydata)
22
- return ln,
 
 
23
 
24
- ani = animation.FuncAnimation(
25
- fig, update, frames=np.linspace(0, 2*np.pi, 100),
26
- init_func=init, blit=True, repeat=False
27
- )
28
 
29
  # Save to MP4
30
  temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
31
- ani.save(temp_video.name, writer='ffmpeg', fps=20)
32
  plt.close(fig)
33
-
34
  return temp_video.name
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def load_image_on_start():
39
  return np.random.rand(700, 700)
40
- # return None
41
-
42
- with gr.Blocks() as demo:
43
- gr.Markdown("## Agent Control with Language")
44
- gr.Markdown('## Say the agent where to go and what to do')
45
-
46
- with gr.Row():
47
- with gr.Column():
48
- request_audio = gr.Audio()
49
- send_btn = gr.Button(value='Send Request')
50
- request_text = gr.Textbox(label="Request:", lines=2, interactive=False)
51
- request_target = gr.Textbox(label='Target:', lines=2)
52
- request_plan = gr.Textbox(label='Plan status:', lines=2)
53
- with gr.Column():
54
- output_env = gr.Video(label="Env:", autoplay=True)
55
-
56
- # EVENTS:
57
- # gr.on(triggers=["load"], fn=load_image_on_start, outputs=output_env_image)
58
- # demo.load(fn=load_image_on_start, outputs=output_env_image)
59
- demo.load(fn=create_animation, outputs=output_env)
60
-
61
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  demo.launch()
 
1
+ # gradio app.py --watch-dirs app.py
2
+
3
  import gradio as gr
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
  import matplotlib.animation as animation
7
  import tempfile
8
+ import torch
9
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
10
+ import torchaudio
11
+ import torchaudio.transforms as T
12
+ from matplotlib.patches import Circle
13
+ from stable_baselines3 import SAC
14
+ from warehouse_env import WarehouseEnv
15
+ from types import SimpleNamespace
16
+
17
+
18
+ # ---------------------------- #
19
+ # global variables
20
+ # ---------------------------- #
21
+ # models
22
+ # a model for the automatic-speech-recognition task
23
+ # device = "cuda:0" if torch.cuda.is_available() else "cpu"
24
+ # torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
25
+ # model_id = "./models_for_proj/librispeech_asr_dummy"
26
+ # model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
27
+ # model.to(device)
28
+ # processor = AutoProcessor.from_pretrained(model_id)
29
+ # asr_pipe = pipeline(
30
+ # "automatic-speech-recognition",
31
+ # model=model,
32
+ # tokenizer=processor.tokenizer,
33
+ # feature_extractor=processor.feature_extractor,
34
+ # max_new_tokens=128,
35
+ # torch_dtype=torch_dtype,
36
+ # device=device,
37
+ # )
38
+ asr_pipe_default = pipeline("automatic-speech-recognition")
39
+
40
+
41
+ # env variables
42
+ rl_model_name = 'agent_policies/sac_warehouse_r_10_working_v1.zip'
43
+ # agent_pos = {'x': 50.0, 'y': 50.0}
44
+ agent_pos = SimpleNamespace(**{'x': 50.0, 'y': 50.0})
45
+ goal_dict = {
46
+ '1': (20, 20),
47
+ '2': (80, 20),
48
+ '3': (80, 80),
49
+ '4': (20, 80),
50
+ }
51
+ targets_x, targets_y = [], []
52
+ for k, v in goal_dict.items():
53
+ targets_x.append(v[0])
54
+ targets_y.append(v[1])
55
+ r_coverage = 10
56
+
57
 
58
 
59
+
60
+ # ---------------------------- #
61
+ # functions
62
+ # ---------------------------- #
63
+ def create_standing_animation():
64
+ path = [(agent_pos.x, agent_pos.y)]
65
+ return create_animation(path, targets_x, targets_y, r_coverage)
66
+
67
+
68
+ def create_animation(path, targets_x, targets_y, r_coverage):
69
+ # path = [(i,i) for i in range(90)]
70
+ # targets_x = [20, 80, 80, 20]
71
+ # targets_y = [20, 20, 80, 80]
72
+ # RADIUS_COVERAGE = 10
73
  fig, ax = plt.subplots(figsize=(7, 7))
74
+
75
+ # agent
76
+ ln1, = plt.plot([path[0][0]], [path[0][1]], marker='o', color='b', alpha=0.5, linewidth=5, markersize=15)
77
+
78
+ # targets
79
+ ln2, = plt.plot(targets_x, targets_y, marker='X', color='orange', alpha=0.5, linestyle='none', markersize=15)
80
+ for t_x, t_y in zip(targets_x, targets_y):
81
+ circle = Circle((t_x, t_y), r_coverage, color='orange', fill=True, alpha=0.3)
82
+ ax.add_patch(circle)
83
+ # plt.tight_layout()
84
 
85
  def init():
86
+ ax.set_xlim([0, 100])
87
+ ax.set_ylim([0, 100])
88
+ ax.set_title(f'Warehouse Env', fontweight="bold", size=10)
89
+ return ln1,
90
 
91
  def update(frame):
92
+ # for each frame, update the data stored on each artist.
93
+ x = [path[frame][0]]
94
+ y = [path[frame][1]]
95
+
96
+ ln1.set_data(x, y)
97
+ return ln1,
98
 
99
+ ani = animation.FuncAnimation(fig, update, frames=len(path),
100
+ init_func=init, blit=True, repeat=False)
101
+ # plt.show()
 
102
 
103
  # Save to MP4
104
  temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
105
+ ani.save(temp_video.name, writer='ffmpeg', fps=30)
106
  plt.close(fig)
 
107
  return temp_video.name
108
 
109
 
110
+ def move_agent(target_input: int):
111
+ if target_input not in goal_dict:
112
+ return create_standing_animation(), 'Did not find a target.'
113
+ # get goal locations:
114
+ goal_x, goal_y = goal_dict[target_input]
115
+ # build the path
116
+ env: WarehouseEnv = WarehouseEnv(render_mode='')
117
+ model = SAC.load(rl_model_name)
118
+ obs, info = env.reset(agent_x=agent_pos.x, agent_y=agent_pos.y, goal_x=goal_x, goal_y=goal_y)
119
+ path = []
120
+ while True:
121
+ action, _ = model.predict(obs)
122
+ obs, rewards, done, trunc, info = env.step(action)
123
+ path.append((env.agent_x, env.agent_y))
124
+ if done:
125
+ break
126
+ if trunc:
127
+ obs, info = env.reset(agent_x=agent_pos.x, agent_y=agent_pos.y, goal_x=goal_x, goal_y=goal_y)
128
+ path = []
129
+
130
+ agent_pos.x = path[-1][0]
131
+ agent_pos.y = path[-1][1]
132
+ # create animation
133
+ video_output = create_animation(path, targets_x, targets_y, r_coverage)
134
+
135
+ # update status
136
+ status = f'Went to target {target_input}.'
137
+
138
+ return video_output, status
139
+
140
 
141
  def load_image_on_start():
142
  return np.random.rand(700, 700)
143
+
144
+ def get_text_request(audio_input):
145
+ audio_input_sr, audio_input_np = audio_input
146
+ audio_input_t = torch.tensor(audio_input_np, dtype=torch.float32)
147
+ target_sr = 16000
148
+ resampler = T.Resample(audio_input_sr, target_sr, dtype=audio_input_t.dtype)
149
+ resampled_audio_input_t: torch.Tensor = resampler(audio_input_t)
150
+ resampled_audio_input_np = resampled_audio_input_t.numpy()
151
+ # result = asr_pipe(resampled_audio_input_np)
152
+ result = asr_pipe_default(resampled_audio_input_np)
153
+ return result["text"]
154
+
155
+ def get_target_from_request(request_text):
156
+ if 'ONE' in request_text:
157
+ return 1
158
+ if 'TWO' in request_text:
159
+ return 2
160
+ if 'THREE' in request_text:
161
+ return 3
162
+ if 'FOUR' in request_text:
163
+ return 4
164
+ return 'NO TARGET FOUND'
165
+
166
+
167
+ def create_demo():
168
+ # main blocks
169
+ with gr.Blocks() as demo:
170
+ gr.Markdown("## Agent Control with Language")
171
+ gr.Markdown('## Say the agent where to go and what to do')
172
+
173
+ with gr.Row():
174
+ with gr.Column():
175
+ request_audio = gr.Microphone(editable=False)
176
+ # send_btn = gr.Button(value='Send Request')
177
+ request_text = gr.Textbox(label="Request:", lines=2, interactive=False)
178
+ request_target = gr.Textbox(label='Target:', lines=2)
179
+ status = gr.Textbox(label='Plan status:', lines=2)
180
+ with gr.Column():
181
+ output_env = gr.Video(label="Env:", autoplay=True)
182
+
183
+ # EVENTS:
184
+ # gr.on(triggers=["load"], fn=load_image_on_start, outputs=output_env_image)
185
+ # demo.load(fn=load_image_on_start, outputs=output_env_image)
186
+ demo.load(fn=create_standing_animation, outputs=output_env)
187
+ # request_audio.stream(fn=get_text_request, inputs=request_audio, outputs=request_text)
188
+ request_audio.stop_recording(fn=get_text_request, inputs=request_audio, outputs=request_text)
189
+ request_text.change(fn=get_target_from_request, inputs=request_text, outputs=request_target)
190
+ request_target.change(fn=move_agent, inputs=request_target, outputs=[output_env, status])
191
+ request_audio.stop_recording(lambda: None, outputs=request_audio)
192
+ return demo
193
+
194
+ # ---------------------------- #
195
+ # main
196
+ # ---------------------------- #
197
+ demo = create_demo()
198
  demo.launch()
draft_1.ipynb ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "id": "bf22c176a849df32",
6
+ "metadata": {
7
+ "ExecuteTime": {
8
+ "end_time": "2025-04-21T06:10:01.065321Z",
9
+ "start_time": "2025-04-21T06:10:01.060267Z"
10
+ }
11
+ },
12
+ "source": [
13
+ "from transformers import pipeline\n",
14
+ "from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, AutoTokenizer, pipeline, AutoFeatureExtractor\n",
15
+ "import torchaudio\n",
16
+ "import torchaudio.transforms as T"
17
+ ],
18
+ "outputs": [],
19
+ "execution_count": 32
20
+ },
21
+ {
22
+ "metadata": {
23
+ "collapsed": true,
24
+ "ExecuteTime": {
25
+ "end_time": "2025-04-21T05:03:48.582040Z",
26
+ "start_time": "2025-04-21T04:51:46.343821Z"
27
+ }
28
+ },
29
+ "cell_type": "code",
30
+ "outputs": [
31
+ {
32
+ "data": {
33
+ "text/plain": [
34
+ "model.safetensors: 0%| | 0.00/151M [00:00<?, ?B/s]"
35
+ ],
36
+ "application/vnd.jupyter.widget-view+json": {
37
+ "version_major": 2,
38
+ "version_minor": 0,
39
+ "model_id": "51ffb4afb57446278c28d690aa1b22e4"
40
+ }
41
+ },
42
+ "metadata": {},
43
+ "output_type": "display_data"
44
+ },
45
+ {
46
+ "data": {
47
+ "text/plain": [
48
+ "generation_config.json: 0%| | 0.00/3.75k [00:00<?, ?B/s]"
49
+ ],
50
+ "application/vnd.jupyter.widget-view+json": {
51
+ "version_major": 2,
52
+ "version_minor": 0,
53
+ "model_id": "86143c7cd15341e39db1e81231d4fd7e"
54
+ }
55
+ },
56
+ "metadata": {},
57
+ "output_type": "display_data"
58
+ },
59
+ {
60
+ "data": {
61
+ "text/plain": [
62
+ "tokenizer_config.json: 0%| | 0.00/283k [00:00<?, ?B/s]"
63
+ ],
64
+ "application/vnd.jupyter.widget-view+json": {
65
+ "version_major": 2,
66
+ "version_minor": 0,
67
+ "model_id": "01d06664af1c4c169175cd38b00fa78e"
68
+ }
69
+ },
70
+ "metadata": {},
71
+ "output_type": "display_data"
72
+ },
73
+ {
74
+ "data": {
75
+ "text/plain": [
76
+ "vocab.json: 0%| | 0.00/836k [00:00<?, ?B/s]"
77
+ ],
78
+ "application/vnd.jupyter.widget-view+json": {
79
+ "version_major": 2,
80
+ "version_minor": 0,
81
+ "model_id": "8833df76fdf24e92bf51c748aa71bc48"
82
+ }
83
+ },
84
+ "metadata": {},
85
+ "output_type": "display_data"
86
+ },
87
+ {
88
+ "data": {
89
+ "text/plain": [
90
+ "tokenizer.json: 0%| | 0.00/2.48M [00:00<?, ?B/s]"
91
+ ],
92
+ "application/vnd.jupyter.widget-view+json": {
93
+ "version_major": 2,
94
+ "version_minor": 0,
95
+ "model_id": "3f5dfd342c574c2698f42c51a567a77e"
96
+ }
97
+ },
98
+ "metadata": {},
99
+ "output_type": "display_data"
100
+ },
101
+ {
102
+ "data": {
103
+ "text/plain": [
104
+ "merges.txt: 0%| | 0.00/494k [00:00<?, ?B/s]"
105
+ ],
106
+ "application/vnd.jupyter.widget-view+json": {
107
+ "version_major": 2,
108
+ "version_minor": 0,
109
+ "model_id": "e9363f15071d48878ef8230bd6c39177"
110
+ }
111
+ },
112
+ "metadata": {},
113
+ "output_type": "display_data"
114
+ },
115
+ {
116
+ "data": {
117
+ "text/plain": [
118
+ "normalizer.json: 0%| | 0.00/52.7k [00:00<?, ?B/s]"
119
+ ],
120
+ "application/vnd.jupyter.widget-view+json": {
121
+ "version_major": 2,
122
+ "version_minor": 0,
123
+ "model_id": "d0e777a74b9a47f3ad6a18254825122b"
124
+ }
125
+ },
126
+ "metadata": {},
127
+ "output_type": "display_data"
128
+ },
129
+ {
130
+ "data": {
131
+ "text/plain": [
132
+ "added_tokens.json: 0%| | 0.00/34.6k [00:00<?, ?B/s]"
133
+ ],
134
+ "application/vnd.jupyter.widget-view+json": {
135
+ "version_major": 2,
136
+ "version_minor": 0,
137
+ "model_id": "fd44e41876984d81a593d55e07e71be6"
138
+ }
139
+ },
140
+ "metadata": {},
141
+ "output_type": "display_data"
142
+ },
143
+ {
144
+ "data": {
145
+ "text/plain": [
146
+ "special_tokens_map.json: 0%| | 0.00/2.19k [00:00<?, ?B/s]"
147
+ ],
148
+ "application/vnd.jupyter.widget-view+json": {
149
+ "version_major": 2,
150
+ "version_minor": 0,
151
+ "model_id": "c3baafd24e1c4c4d9dcaf5a4715e846e"
152
+ }
153
+ },
154
+ "metadata": {},
155
+ "output_type": "display_data"
156
+ },
157
+ {
158
+ "data": {
159
+ "text/plain": [
160
+ "preprocessor_config.json: 0%| | 0.00/185k [00:00<?, ?B/s]"
161
+ ],
162
+ "application/vnd.jupyter.widget-view+json": {
163
+ "version_major": 2,
164
+ "version_minor": 0,
165
+ "model_id": "3a671889c0504b50bcff2aec93497d78"
166
+ }
167
+ },
168
+ "metadata": {},
169
+ "output_type": "display_data"
170
+ },
171
+ {
172
+ "name": "stderr",
173
+ "output_type": "stream",
174
+ "text": [
175
+ "Device set to use mps:0\n"
176
+ ]
177
+ }
178
+ ],
179
+ "execution_count": 4,
180
+ "source": [
181
+ "\n",
182
+ "pipe = pipeline(model=\"openai/whisper-tiny\", task=\"automatic-speech-recognition\")\n"
183
+ ],
184
+ "id": "initial_id"
185
+ },
186
+ {
187
+ "metadata": {},
188
+ "cell_type": "code",
189
+ "outputs": [],
190
+ "execution_count": null,
191
+ "source": [
192
+ "# Load audio file\n",
193
+ "waveform_1, sample_rate = torchaudio.load(\"sample.wav\")\n",
194
+ "# Target sampling rate (e.g., 16000 Hz for Whisper)\n",
195
+ "target_sr = 16000\n",
196
+ "\n",
197
+ "resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sr, dtype=waveform.dtype)\n",
198
+ "waveform = resampler(waveform_1)\n",
199
+ "waveform_np = waveform.squeeze().numpy()\n",
200
+ "\n",
201
+ "print(waveform.shape) # (channels, samples) — usually (1, N)\n",
202
+ "print(sample_rate)\n",
203
+ "print(waveform_np)"
204
+ ],
205
+ "id": "dc202f529230fa87"
206
+ },
207
+ {
208
+ "metadata": {
209
+ "ExecuteTime": {
210
+ "end_time": "2025-04-21T05:08:38.144954Z",
211
+ "start_time": "2025-04-21T05:08:38.087644Z"
212
+ }
213
+ },
214
+ "cell_type": "code",
215
+ "source": [
216
+ "save_dir = \"./models_for_proj/whisper-tiny\"\n",
217
+ "device = 'cpu'\n",
218
+ "pipe.generation_config.save_pretrained(save_dir)\n",
219
+ "pipe.tokenizer.save_pretrained(save_dir)\n",
220
+ "pipe.feature_extractor.save_pretrained(save_dir)\n"
221
+ ],
222
+ "id": "ed09605af0b78939",
223
+ "outputs": [
224
+ {
225
+ "data": {
226
+ "text/plain": [
227
+ "['./models_for_proj/whisper-tiny/preprocessor_config.json']"
228
+ ]
229
+ },
230
+ "execution_count": 6,
231
+ "metadata": {},
232
+ "output_type": "execute_result"
233
+ }
234
+ ],
235
+ "execution_count": 6
236
+ },
237
+ {
238
+ "metadata": {
239
+ "ExecuteTime": {
240
+ "end_time": "2025-04-21T05:35:59.540770Z",
241
+ "start_time": "2025-04-21T05:35:59.476164Z"
242
+ }
243
+ },
244
+ "cell_type": "code",
245
+ "source": [
246
+ "\n",
247
+ "# model = AutoModelForSpeechSeq2Seq.from_pretrained(save_dir, device=device)\n",
248
+ "# model.config.forced_decoder_ids = None\n",
249
+ "# processor = AutoProcessor.from_pretrained(save_dir, device=device)\n",
250
+ "# tokenizer = AutoTokenizer.from_pretrained(save_dir, device=device)\n",
251
+ "# feature_extractor = AutoFeatureExtractor.from_pretrained(save_dir, device=device)\n",
252
+ "# pipe = pipeline(\"automatic-speech-recognition\", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)\n",
253
+ "# result = pipe(\"sample.wav\")\n",
254
+ "# result[\"text\"]"
255
+ ],
256
+ "id": "1dcd38e5ca08781b",
257
+ "outputs": [
258
+ {
259
+ "ename": "TypeError",
260
+ "evalue": "WhisperForConditionalGeneration.__init__() got an unexpected keyword argument 'device'",
261
+ "output_type": "error",
262
+ "traceback": [
263
+ "\u001B[31m---------------------------------------------------------------------------\u001B[39m",
264
+ "\u001B[31mTypeError\u001B[39m Traceback (most recent call last)",
265
+ "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[30]\u001B[39m\u001B[32m, line 3\u001B[39m\n\u001B[32m 1\u001B[39m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mtransformers\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m AutoModelForSpeechSeq2Seq, AutoProcessor, AutoTokenizer, pipeline, AutoFeatureExtractor\n\u001B[32m 2\u001B[39m device = \u001B[33m'\u001B[39m\u001B[33mcpu\u001B[39m\u001B[33m'\u001B[39m\n\u001B[32m----> \u001B[39m\u001B[32m3\u001B[39m model = \u001B[43mAutoModelForSpeechSeq2Seq\u001B[49m\u001B[43m.\u001B[49m\u001B[43mfrom_pretrained\u001B[49m\u001B[43m(\u001B[49m\u001B[43msave_dir\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdevice\u001B[49m\u001B[43m=\u001B[49m\u001B[43mdevice\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 4\u001B[39m model.config.forced_decoder_ids = \u001B[38;5;28;01mNone\u001B[39;00m\n\u001B[32m 5\u001B[39m processor = AutoProcessor.from_pretrained(save_dir, device=device)\n",
266
+ "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py:573\u001B[39m, in \u001B[36m_BaseAutoModelClass.from_pretrained\u001B[39m\u001B[34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001B[39m\n\u001B[32m 571\u001B[39m \u001B[38;5;28;01melif\u001B[39;00m \u001B[38;5;28mtype\u001B[39m(config) \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mcls\u001B[39m._model_mapping.keys():\n\u001B[32m 572\u001B[39m model_class = _get_model_class(config, \u001B[38;5;28mcls\u001B[39m._model_mapping)\n\u001B[32m--> \u001B[39m\u001B[32m573\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mmodel_class\u001B[49m\u001B[43m.\u001B[49m\u001B[43mfrom_pretrained\u001B[49m\u001B[43m(\u001B[49m\n\u001B[32m 574\u001B[39m \u001B[43m \u001B[49m\u001B[43mpretrained_model_name_or_path\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43mmodel_args\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mconfig\u001B[49m\u001B[43m=\u001B[49m\u001B[43mconfig\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mhub_kwargs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\n\u001B[32m 575\u001B[39m \u001B[43m \u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 576\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[32m 577\u001B[39m \u001B[33mf\u001B[39m\u001B[33m\"\u001B[39m\u001B[33mUnrecognized configuration class \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mconfig.\u001B[34m__class__\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m for this kind of AutoModel: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28mcls\u001B[39m.\u001B[34m__name__\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m.\u001B[39m\u001B[38;5;130;01m\\n\u001B[39;00m\u001B[33m\"\u001B[39m\n\u001B[32m 578\u001B[39m 
\u001B[33mf\u001B[39m\u001B[33m\"\u001B[39m\u001B[33mModel type should be one of \u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[33m'\u001B[39m\u001B[33m, \u001B[39m\u001B[33m'\u001B[39m.join(c.\u001B[34m__name__\u001B[39m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mfor\u001B[39;00m\u001B[38;5;250m \u001B[39mc\u001B[38;5;250m \u001B[39m\u001B[38;5;129;01min\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28mcls\u001B[39m._model_mapping.keys())\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m.\u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 579\u001B[39m )\n",
267
+ "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/modeling_utils.py:272\u001B[39m, in \u001B[36mrestore_default_torch_dtype.<locals>._wrapper\u001B[39m\u001B[34m(*args, **kwargs)\u001B[39m\n\u001B[32m 270\u001B[39m old_dtype = torch.get_default_dtype()\n\u001B[32m 271\u001B[39m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[32m--> \u001B[39m\u001B[32m272\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 273\u001B[39m \u001B[38;5;28;01mfinally\u001B[39;00m:\n\u001B[32m 274\u001B[39m torch.set_default_dtype(old_dtype)\n",
268
+ "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/modeling_utils.py:4401\u001B[39m, in \u001B[36mPreTrainedModel.from_pretrained\u001B[39m\u001B[34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001B[39m\n\u001B[32m 4395\u001B[39m config = \u001B[38;5;28mcls\u001B[39m._autoset_attn_implementation(\n\u001B[32m 4396\u001B[39m config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map\n\u001B[32m 4397\u001B[39m )\n\u001B[32m 4399\u001B[39m \u001B[38;5;28;01mwith\u001B[39;00m ContextManagers(model_init_context):\n\u001B[32m 4400\u001B[39m \u001B[38;5;66;03m# Let's make sure we don't run the init function of buffer modules\u001B[39;00m\n\u001B[32m-> \u001B[39m\u001B[32m4401\u001B[39m model = \u001B[38;5;28;43mcls\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mconfig\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43mmodel_args\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mmodel_kwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 4403\u001B[39m \u001B[38;5;66;03m# Make sure to tie the weights correctly\u001B[39;00m\n\u001B[32m 4404\u001B[39m model.tie_weights()\n",
269
+ "\u001B[31mTypeError\u001B[39m: WhisperForConditionalGeneration.__init__() got an unexpected keyword argument 'device'"
270
+ ]
271
+ }
272
+ ],
273
+ "execution_count": 30
274
+ },
275
+ {
276
+ "metadata": {
277
+ "ExecuteTime": {
278
+ "end_time": "2025-04-21T06:13:00.420733Z",
279
+ "start_time": "2025-04-21T06:13:00.033330Z"
280
+ }
281
+ },
282
+ "cell_type": "code",
283
+ "source": [
284
+ "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
285
+ "# load dummy dataset and read audio files\n",
286
+ "\n",
287
+ "# input\n",
288
+ "waveform, sample_rate = torchaudio.load(\"sample.wav\")\n",
289
+ "target_sr = 16000\n",
290
+ "resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sr, dtype=waveform.dtype)\n",
291
+ "waveform = resampler(waveform)\n",
292
+ "waveform_np = waveform.squeeze().numpy()\n",
293
+ "\n",
294
+ "\n",
295
+ "processor = WhisperProcessor.from_pretrained(save_dir)\n",
296
+ "model = WhisperForConditionalGeneration.from_pretrained(save_dir)\n",
297
+ "model.config.forced_decoder_ids = None\n",
298
+ "\n",
299
+ "input_features = processor(waveform_np, sampling_rate=target_sr, return_tensors=\"pt\", device=device).input_features\n",
300
+ "\n",
301
+ "# generate token ids\n",
302
+ "predicted_ids = model.generate(input_features)\n",
303
+ "# decode token ids to text\n",
304
+ "transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)\n",
305
+ "# ['<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|endoftext|>']\n",
306
+ "print(transcription)\n",
307
+ "transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
308
+ "# [' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.']\n",
309
+ "print(transcription)"
310
+ ],
311
+ "id": "b0865456fed26d31",
312
+ "outputs": [
313
+ {
314
+ "name": "stderr",
315
+ "output_type": "stream",
316
+ "text": [
317
+ "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n"
318
+ ]
319
+ },
320
+ {
321
+ "ename": "ValueError",
322
+ "evalue": "You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument in favour of `input_ids` or `decoder_input_ids` respectively.",
323
+ "output_type": "error",
324
+ "traceback": [
325
+ "\u001B[31m---------------------------------------------------------------------------\u001B[39m",
326
+ "\u001B[31mValueError\u001B[39m Traceback (most recent call last)",
327
+ "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[34]\u001B[39m\u001B[32m, line 19\u001B[39m\n\u001B[32m 16\u001B[39m input_features = processor(waveform_np, sampling_rate=target_sr, return_tensors=\u001B[33m\"\u001B[39m\u001B[33mpt\u001B[39m\u001B[33m\"\u001B[39m, device=device).input_features\n\u001B[32m 18\u001B[39m \u001B[38;5;66;03m# generate token ids\u001B[39;00m\n\u001B[32m---> \u001B[39m\u001B[32m19\u001B[39m predicted_ids = \u001B[43mmodel\u001B[49m\u001B[43m.\u001B[49m\u001B[43mgenerate\u001B[49m\u001B[43m(\u001B[49m\u001B[43minput_features\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 20\u001B[39m \u001B[38;5;66;03m# decode token ids to text\u001B[39;00m\n\u001B[32m 21\u001B[39m transcription = processor.batch_decode(predicted_ids, skip_special_tokens=\u001B[38;5;28;01mFalse\u001B[39;00m)\n",
328
+ "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/models/whisper/generation_whisper.py:774\u001B[39m, in \u001B[36mWhisperGenerationMixin.generate\u001B[39m\u001B[34m(self, input_features, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, prompt_condition_type, condition_on_prev_tokens, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, num_segment_frames, attention_mask, time_precision, time_precision_features, return_token_timestamps, return_segments, return_dict_in_generate, force_unique_generate_call, **kwargs)\u001B[39m\n\u001B[32m 765\u001B[39m proc.set_begin_index(decoder_input_ids.shape[-\u001B[32m1\u001B[39m])\n\u001B[32m 767\u001B[39m \u001B[38;5;66;03m# 6.6 Run generate with fallback\u001B[39;00m\n\u001B[32m 768\u001B[39m (\n\u001B[32m 769\u001B[39m seek_sequences,\n\u001B[32m 770\u001B[39m seek_outputs,\n\u001B[32m 771\u001B[39m should_skip,\n\u001B[32m 772\u001B[39m do_condition_on_prev_tokens,\n\u001B[32m 773\u001B[39m model_output_type,\n\u001B[32m--> \u001B[39m\u001B[32m774\u001B[39m ) = \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mgenerate_with_fallback\u001B[49m\u001B[43m(\u001B[49m\n\u001B[32m 775\u001B[39m \u001B[43m \u001B[49m\u001B[43msegment_input\u001B[49m\u001B[43m=\u001B[49m\u001B[43msegment_input\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 776\u001B[39m \u001B[43m \u001B[49m\u001B[43mdecoder_input_ids\u001B[49m\u001B[43m=\u001B[49m\u001B[43mdecoder_input_ids\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 777\u001B[39m \u001B[43m \u001B[49m\u001B[43mcur_bsz\u001B[49m\u001B[43m=\u001B[49m\u001B[43mcur_bsz\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 778\u001B[39m \u001B[43m \u001B[49m\u001B[43mbatch_idx_map\u001B[49m\u001B[43m=\u001B[49m\u001B[43mbatch_idx_map\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 779\u001B[39m 
\u001B[43m \u001B[49m\u001B[43mseek\u001B[49m\u001B[43m=\u001B[49m\u001B[43mseek\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 780\u001B[39m \u001B[43m \u001B[49m\u001B[43mnum_segment_frames\u001B[49m\u001B[43m=\u001B[49m\u001B[43mnum_segment_frames\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 781\u001B[39m \u001B[43m \u001B[49m\u001B[43mmax_frames\u001B[49m\u001B[43m=\u001B[49m\u001B[43mmax_frames\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 782\u001B[39m \u001B[43m \u001B[49m\u001B[43mtemperatures\u001B[49m\u001B[43m=\u001B[49m\u001B[43mtemperatures\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 783\u001B[39m \u001B[43m \u001B[49m\u001B[43mgeneration_config\u001B[49m\u001B[43m=\u001B[49m\u001B[43mgeneration_config\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 784\u001B[39m \u001B[43m \u001B[49m\u001B[43mlogits_processor\u001B[49m\u001B[43m=\u001B[49m\u001B[43mlogits_processor\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 785\u001B[39m \u001B[43m \u001B[49m\u001B[43mstopping_criteria\u001B[49m\u001B[43m=\u001B[49m\u001B[43mstopping_criteria\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 786\u001B[39m \u001B[43m \u001B[49m\u001B[43mprefix_allowed_tokens_fn\u001B[49m\u001B[43m=\u001B[49m\u001B[43mprefix_allowed_tokens_fn\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 787\u001B[39m \u001B[43m \u001B[49m\u001B[43msynced_gpus\u001B[49m\u001B[43m=\u001B[49m\u001B[43msynced_gpus\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 788\u001B[39m \u001B[43m \u001B[49m\u001B[43mreturn_token_timestamps\u001B[49m\u001B[43m=\u001B[49m\u001B[43mreturn_token_timestamps\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 789\u001B[39m \u001B[43m \u001B[49m\u001B[43mdo_condition_on_prev_tokens\u001B[49m\u001B[43m=\u001B[49m\u001B[43mdo_condition_on_prev_tokens\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 790\u001B[39m \u001B[43m \u001B[49m\u001B[43mis_shortform\u001B[49m\u001B[43m=\u001B[49m\u001B[43mis_shortform\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 791\u001B[39m \u001B[43m 
\u001B[49m\u001B[43mbatch_size\u001B[49m\u001B[43m=\u001B[49m\u001B[43mbatch_size\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 792\u001B[39m \u001B[43m \u001B[49m\u001B[43mattention_mask\u001B[49m\u001B[43m=\u001B[49m\u001B[43mattention_mask\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 793\u001B[39m \u001B[43m \u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m=\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 794\u001B[39m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 796\u001B[39m \u001B[38;5;66;03m# 6.7 In every generated sequence, split by timestamp tokens and extract segments\u001B[39;00m\n\u001B[32m 797\u001B[39m \u001B[38;5;28;01mfor\u001B[39;00m i, seek_sequence \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28menumerate\u001B[39m(seek_sequences):\n",
329
+ "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/models/whisper/generation_whisper.py:950\u001B[39m, in \u001B[36mWhisperGenerationMixin.generate_with_fallback\u001B[39m\u001B[34m(self, segment_input, decoder_input_ids, cur_bsz, batch_idx_map, seek, num_segment_frames, max_frames, temperatures, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_token_timestamps, do_condition_on_prev_tokens, is_shortform, batch_size, attention_mask, kwargs)\u001B[39m\n\u001B[32m 945\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m generate_kwargs.get(\u001B[33m\"\u001B[39m\u001B[33mencoder_outputs\u001B[39m\u001B[33m\"\u001B[39m) \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[32m 946\u001B[39m generate_kwargs[\u001B[33m\"\u001B[39m\u001B[33mencoder_outputs\u001B[39m\u001B[33m\"\u001B[39m] = F.pad(\n\u001B[32m 947\u001B[39m generate_kwargs[\u001B[33m\"\u001B[39m\u001B[33mencoder_outputs\u001B[39m\u001B[33m\"\u001B[39m], (\u001B[32m0\u001B[39m, \u001B[32m0\u001B[39m, \u001B[32m0\u001B[39m, \u001B[32m0\u001B[39m, \u001B[32m0\u001B[39m, batch_size - cur_bsz), value=\u001B[32m0\u001B[39m\n\u001B[32m 948\u001B[39m )\n\u001B[32m--> \u001B[39m\u001B[32m950\u001B[39m seek_outputs = \u001B[38;5;28;43msuper\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m.\u001B[49m\u001B[43mgenerate\u001B[49m\u001B[43m(\u001B[49m\n\u001B[32m 951\u001B[39m \u001B[43m \u001B[49m\u001B[43msegment_input\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 952\u001B[39m \u001B[43m \u001B[49m\u001B[43mgeneration_config\u001B[49m\u001B[43m=\u001B[49m\u001B[43mgeneration_config\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 953\u001B[39m \u001B[43m \u001B[49m\u001B[43mlogits_processor\u001B[49m\u001B[43m=\u001B[49m\u001B[43mlogits_processor\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 954\u001B[39m \u001B[43m 
\u001B[49m\u001B[43mstopping_criteria\u001B[49m\u001B[43m=\u001B[49m\u001B[43mstopping_criteria\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 955\u001B[39m \u001B[43m \u001B[49m\u001B[43mprefix_allowed_tokens_fn\u001B[49m\u001B[43m=\u001B[49m\u001B[43mprefix_allowed_tokens_fn\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 956\u001B[39m \u001B[43m \u001B[49m\u001B[43msynced_gpus\u001B[49m\u001B[43m=\u001B[49m\u001B[43msynced_gpus\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 957\u001B[39m \u001B[43m \u001B[49m\u001B[43mdecoder_input_ids\u001B[49m\u001B[43m=\u001B[49m\u001B[43mdecoder_input_ids\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 958\u001B[39m \u001B[43m \u001B[49m\u001B[43mattention_mask\u001B[49m\u001B[43m=\u001B[49m\u001B[43mattention_mask\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 959\u001B[39m \u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mgenerate_kwargs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 960\u001B[39m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 962\u001B[39m model_output_type = \u001B[38;5;28mtype\u001B[39m(seek_outputs)\n\u001B[32m 964\u001B[39m \u001B[38;5;66;03m# post-process sequence tokens and outputs to be in list form\u001B[39;00m\n",
330
+ "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001B[39m, in \u001B[36mcontext_decorator.<locals>.decorate_context\u001B[39m\u001B[34m(*args, **kwargs)\u001B[39m\n\u001B[32m 113\u001B[39m \u001B[38;5;129m@functools\u001B[39m.wraps(func)\n\u001B[32m 114\u001B[39m \u001B[38;5;28;01mdef\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34mdecorate_context\u001B[39m(*args, **kwargs):\n\u001B[32m 115\u001B[39m \u001B[38;5;28;01mwith\u001B[39;00m ctx_factory():\n\u001B[32m--> \u001B[39m\u001B[32m116\u001B[39m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mfunc\u001B[49m\u001B[43m(\u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
331
+ "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/generation/utils.py:2219\u001B[39m, in \u001B[36mGenerationMixin.generate\u001B[39m\u001B[34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, use_model_defaults, **kwargs)\u001B[39m\n\u001B[32m 2208\u001B[39m warnings.warn(\n\u001B[32m 2209\u001B[39m \u001B[33m\"\u001B[39m\u001B[33mYou are calling .generate() with the `input_ids` being on a device type different\u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 2210\u001B[39m \u001B[33mf\u001B[39m\u001B[33m\"\u001B[39m\u001B[33m than your model\u001B[39m\u001B[33m'\u001B[39m\u001B[33ms device. `input_ids` is on \u001B[39m\u001B[38;5;132;01m{\u001B[39;00minput_ids.device.type\u001B[38;5;132;01m}\u001B[39;00m\u001B[33m, whereas the model\u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m (...)\u001B[39m\u001B[32m 2215\u001B[39m \u001B[38;5;167;01mUserWarning\u001B[39;00m,\n\u001B[32m 2216\u001B[39m )\n\u001B[32m 2218\u001B[39m \u001B[38;5;66;03m# 9. 
prepare logits processors and stopping criteria\u001B[39;00m\n\u001B[32m-> \u001B[39m\u001B[32m2219\u001B[39m prepared_logits_processor = \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43m_get_logits_processor\u001B[49m\u001B[43m(\u001B[49m\n\u001B[32m 2220\u001B[39m \u001B[43m \u001B[49m\u001B[43mgeneration_config\u001B[49m\u001B[43m=\u001B[49m\u001B[43mgeneration_config\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2221\u001B[39m \u001B[43m \u001B[49m\u001B[43minput_ids_seq_length\u001B[49m\u001B[43m=\u001B[49m\u001B[43minput_ids_length\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2222\u001B[39m \u001B[43m \u001B[49m\u001B[43mencoder_input_ids\u001B[49m\u001B[43m=\u001B[49m\u001B[43minputs_tensor\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2223\u001B[39m \u001B[43m \u001B[49m\u001B[43mprefix_allowed_tokens_fn\u001B[49m\u001B[43m=\u001B[49m\u001B[43mprefix_allowed_tokens_fn\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2224\u001B[39m \u001B[43m \u001B[49m\u001B[43mlogits_processor\u001B[49m\u001B[43m=\u001B[49m\u001B[43mlogits_processor\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2225\u001B[39m \u001B[43m \u001B[49m\u001B[43mdevice\u001B[49m\u001B[43m=\u001B[49m\u001B[43minputs_tensor\u001B[49m\u001B[43m.\u001B[49m\u001B[43mdevice\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2226\u001B[39m \u001B[43m \u001B[49m\u001B[43mmodel_kwargs\u001B[49m\u001B[43m=\u001B[49m\u001B[43mmodel_kwargs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2227\u001B[39m \u001B[43m \u001B[49m\u001B[43mnegative_prompt_ids\u001B[49m\u001B[43m=\u001B[49m\u001B[43mnegative_prompt_ids\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2228\u001B[39m \u001B[43m \u001B[49m\u001B[43mnegative_prompt_attention_mask\u001B[49m\u001B[43m=\u001B[49m\u001B[43mnegative_prompt_attention_mask\u001B[49m\u001B[43m,\u001B[49m\n\u001B[32m 2229\u001B[39m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 2230\u001B[39m prepared_stopping_criteria = \u001B[38;5;28mself\u001B[39m._get_stopping_criteria(\n\u001B[32m 
2231\u001B[39m generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs\n\u001B[32m 2232\u001B[39m )\n\u001B[32m 2234\u001B[39m \u001B[38;5;66;03m# Set model_kwargs `use_cache` so we can use it later in forward runs\u001B[39;00m\n",
332
+ "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/generation/utils.py:1083\u001B[39m, in \u001B[36mGenerationMixin._get_logits_processor\u001B[39m\u001B[34m(self, generation_config, input_ids_seq_length, encoder_input_ids, prefix_allowed_tokens_fn, logits_processor, device, model_kwargs, negative_prompt_ids, negative_prompt_attention_mask)\u001B[39m\n\u001B[32m 1074\u001B[39m processors.append(\n\u001B[32m 1075\u001B[39m SuppressTokensAtBeginLogitsProcessor(\n\u001B[32m 1076\u001B[39m generation_config.begin_suppress_tokens,\n\u001B[32m (...)\u001B[39m\u001B[32m 1079\u001B[39m )\n\u001B[32m 1080\u001B[39m )\n\u001B[32m 1081\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m generation_config.forced_decoder_ids \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[32m 1082\u001B[39m \u001B[38;5;66;03m# TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT\u001B[39;00m\n\u001B[32m-> \u001B[39m\u001B[32m1083\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[32m 1084\u001B[39m \u001B[33m\"\u001B[39m\u001B[33mYou have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument \u001B[39m\u001B[33m\"\u001B[39m\n\u001B[32m 1085\u001B[39m \u001B[33m\"\u001B[39m\u001B[33min favour of `input_ids` or `decoder_input_ids` respectively.\u001B[39m\u001B[33m\"\u001B[39m,\n\u001B[32m 1086\u001B[39m )\n\u001B[32m 1088\u001B[39m \u001B[38;5;66;03m# TODO (joao): find a strategy to specify the order of the processors\u001B[39;00m\n\u001B[32m 1089\u001B[39m processors = \u001B[38;5;28mself\u001B[39m._merge_criteria_processor_list(processors, logits_processor)\n",
333
+ "\u001B[31mValueError\u001B[39m: You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument in favour of `input_ids` or `decoder_input_ids` respectively."
334
+ ]
335
+ }
336
+ ],
337
+ "execution_count": 34
338
+ },
339
+ {
340
+ "metadata": {
341
+ "ExecuteTime": {
342
+ "end_time": "2025-04-21T06:15:41.079099Z",
343
+ "start_time": "2025-04-21T06:15:37.277194Z"
344
+ }
345
+ },
346
+ "cell_type": "code",
347
+ "source": [
348
+ "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
349
+ "from datasets import load_dataset\n",
350
+ "\n",
351
+ "# load model and processor\n",
352
+ "processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n",
353
+ "model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n",
354
+ "model.config.forced_decoder_ids = None\n",
355
+ "\n",
356
+ "# load dummy dataset and read audio files\n",
357
+ "ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
358
+ "sample = ds[0][\"audio\"]\n",
359
+ "input_features = processor(sample[\"array\"], sampling_rate=sample[\"sampling_rate\"], return_tensors=\"pt\").input_features\n",
360
+ "\n",
361
+ "# generate token ids\n",
362
+ "predicted_ids = model.generate(input_features)\n",
363
+ "# decode token ids to text\n",
364
+ "transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)\n",
365
+ "processor(transcription)\n",
366
+ "transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
367
+ "processor(transcription)\n"
368
+ ],
369
+ "id": "b4137e08d1a516e5",
370
+ "outputs": [
371
+ {
372
+ "name": "stderr",
373
+ "output_type": "stream",
374
+ "text": [
375
+ "It is strongly recommended to pass the `sampling_rate` argument to `WhisperFeatureExtractor()`. Failing to do so can result in silent errors that might be hard to debug.\n"
376
+ ]
377
+ },
378
+ {
379
+ "ename": "ValueError",
380
+ "evalue": "could not convert string to float: ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'",
381
+ "output_type": "error",
382
+ "traceback": [
383
+ "\u001B[31m---------------------------------------------------------------------------\u001B[39m",
384
+ "\u001B[31mValueError\u001B[39m Traceback (most recent call last)",
385
+ "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[37]\u001B[39m\u001B[32m, line 18\u001B[39m\n\u001B[32m 16\u001B[39m \u001B[38;5;66;03m# decode token ids to text\u001B[39;00m\n\u001B[32m 17\u001B[39m transcription = processor.batch_decode(predicted_ids, skip_special_tokens=\u001B[38;5;28;01mFalse\u001B[39;00m)\n\u001B[32m---> \u001B[39m\u001B[32m18\u001B[39m \u001B[43mprocessor\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtranscription\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 19\u001B[39m transcription = processor.batch_decode(predicted_ids, skip_special_tokens=\u001B[38;5;28;01mTrue\u001B[39;00m)\n\u001B[32m 20\u001B[39m processor(transcription)\n",
386
+ "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/models/whisper/processing_whisper.py:69\u001B[39m, in \u001B[36mWhisperProcessor.__call__\u001B[39m\u001B[34m(self, *args, **kwargs)\u001B[39m\n\u001B[32m 66\u001B[39m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[33m\"\u001B[39m\u001B[33mYou need to specify either an `audio` or `text` input to process.\u001B[39m\u001B[33m\"\u001B[39m)\n\u001B[32m 68\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m audio \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[32m---> \u001B[39m\u001B[32m69\u001B[39m inputs = \u001B[38;5;28;43mself\u001B[39;49m\u001B[43m.\u001B[49m\u001B[43mfeature_extractor\u001B[49m\u001B[43m(\u001B[49m\u001B[43maudio\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msampling_rate\u001B[49m\u001B[43m=\u001B[49m\u001B[43msampling_rate\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m*\u001B[49m\u001B[43m*\u001B[49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 70\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m text \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[32m 71\u001B[39m encodings = \u001B[38;5;28mself\u001B[39m.tokenizer(text, **kwargs)\n",
387
+ "\u001B[36mFile \u001B[39m\u001B[32m~/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/models/whisper/feature_extraction_whisper.py:281\u001B[39m, in \u001B[36mWhisperFeatureExtractor.__call__\u001B[39m\u001B[34m(self, raw_speech, truncation, pad_to_multiple_of, return_tensors, return_attention_mask, padding, max_length, sampling_rate, do_normalize, device, return_token_timestamps, **kwargs)\u001B[39m\n\u001B[32m 279\u001B[39m raw_speech = [np.asarray([speech], dtype=np.float32).T \u001B[38;5;28;01mfor\u001B[39;00m speech \u001B[38;5;129;01min\u001B[39;00m raw_speech]\n\u001B[32m 280\u001B[39m \u001B[38;5;28;01melif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m is_batched \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(raw_speech, np.ndarray):\n\u001B[32m--> \u001B[39m\u001B[32m281\u001B[39m raw_speech = \u001B[43mnp\u001B[49m\u001B[43m.\u001B[49m\u001B[43masarray\u001B[49m\u001B[43m(\u001B[49m\u001B[43mraw_speech\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdtype\u001B[49m\u001B[43m=\u001B[49m\u001B[43mnp\u001B[49m\u001B[43m.\u001B[49m\u001B[43mfloat32\u001B[49m\u001B[43m)\u001B[49m\n\u001B[32m 282\u001B[39m \u001B[38;5;28;01melif\u001B[39;00m \u001B[38;5;28misinstance\u001B[39m(raw_speech, np.ndarray) \u001B[38;5;129;01mand\u001B[39;00m raw_speech.dtype \u001B[38;5;129;01mis\u001B[39;00m np.dtype(np.float64):\n\u001B[32m 283\u001B[39m raw_speech = raw_speech.astype(np.float32)\n",
388
+ "\u001B[31mValueError\u001B[39m: could not convert string to float: ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'"
389
+ ]
390
+ }
391
+ ],
392
+ "execution_count": 37
393
+ },
394
+ {
395
+ "metadata": {
396
+ "ExecuteTime": {
397
+ "end_time": "2025-04-21T05:31:26.352787Z",
398
+ "start_time": "2025-04-21T05:31:26.343398Z"
399
+ }
400
+ },
401
+ "cell_type": "code",
402
+ "source": "",
403
+ "id": "37fa63b1c22f4a69",
404
+ "outputs": [
405
+ {
406
+ "name": "stdout",
407
+ "output_type": "stream",
408
+ "text": [
409
+ "torch.Size([1, 24192])\n",
410
+ "24000\n",
411
+ "[ 0.0000000e+00 0.0000000e+00 0.0000000e+00 ... -1.3932839e-05\n",
412
+ " -3.6663318e-05 -1.3932839e-05]\n"
413
+ ]
414
+ }
415
+ ],
416
+ "execution_count": 25
417
+ },
418
+ {
419
+ "metadata": {
420
+ "ExecuteTime": {
421
+ "end_time": "2025-04-21T06:28:40.294060Z",
422
+ "start_time": "2025-04-21T06:28:35.493462Z"
423
+ }
424
+ },
425
+ "cell_type": "code",
426
+ "source": [
427
+ "import torch\n",
428
+ "from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline\n",
429
+ "from datasets import load_dataset\n",
430
+ "\n",
431
+ "\n",
432
+ "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
433
+ "torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32\n",
434
+ "\n",
435
+ "# model_id = \"distil-whisper/distil-small.en\"\n",
436
+ "model_id = \"./models_for_proj/librispeech_asr_dummy\"\n",
437
+ "\n",
438
+ "model = AutoModelForSpeechSeq2Seq.from_pretrained(\n",
439
+ " model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True\n",
440
+ ")\n",
441
+ "model.to(device)\n",
442
+ "\n",
443
+ "processor = AutoProcessor.from_pretrained(model_id)\n",
444
+ "\n",
445
+ "pipe = pipeline(\n",
446
+ " \"automatic-speech-recognition\",\n",
447
+ " model=model,\n",
448
+ " tokenizer=processor.tokenizer,\n",
449
+ " feature_extractor=processor.feature_extractor,\n",
450
+ " max_new_tokens=128,\n",
451
+ " torch_dtype=torch_dtype,\n",
452
+ " device=device,\n",
453
+ ")\n",
454
+ "\n",
455
+ "# dataset = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
456
+ "# sample = dataset[0][\"audio\"]\n",
457
+ "# result = pipe(sample)\n",
458
+ "\n",
459
+ "# input\n",
460
+ "waveform, sample_rate = torchaudio.load(\"sample.wav\")\n",
461
+ "target_sr = 16000\n",
462
+ "resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sr, dtype=waveform.dtype)\n",
463
+ "waveform = resampler(waveform)\n",
464
+ "waveform_np = waveform.squeeze().numpy()\n",
465
+ "# sample = dataset[2][\"audio\"]\n",
466
+ "\n",
467
+ "# result = pipe(sample)\n",
468
+ "result = pipe(waveform_np)\n",
469
+ "print(result[\"text\"])\n"
470
+ ],
471
+ "id": "e7f0a5bccb4e204f",
472
+ "outputs": [
473
+ {
474
+ "name": "stderr",
475
+ "output_type": "stream",
476
+ "text": [
477
+ "Device set to use cpu\n",
478
+ "/Users/perchik/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/pipelines/automatic_speech_recognition.py:312: FutureWarning: `max_new_tokens` is deprecated and will be removed in version 4.49 of Transformers. To remove this warning, pass `max_new_tokens` as a key inside `generate_kwargs` instead.\n",
479
+ " warnings.warn(\n",
480
+ "/Users/perchik/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/models/whisper/generation_whisper.py:573: FutureWarning: The input name `inputs` is deprecated. Please make sure to use `input_features` instead.\n",
481
+ " warnings.warn(\n",
482
+ "`generation_config` default values have been modified to match model-specific defaults: {'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361], 'begin_suppress_tokens': [220, 50256]}. If this is not desired, please set these values explicitly.\n",
483
+ "A custom logits processor of type <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> will take precedence. Please check the docstring of <class 'transformers.generation.logits_process.SuppressTokensLogitsProcessor'> to see related `.generate()` flags.\n",
484
+ "A custom logits processor of type <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> has been passed to `.generate()`, but it was also created in `.generate()`, given its parameterization. The custom <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> will take precedence. Please check the docstring of <class 'transformers.generation.logits_process.SuppressTokensAtBeginLogitsProcessor'> to see related `.generate()` flags.\n"
485
+ ]
486
+ },
487
+ {
488
+ "name": "stdout",
489
+ "output_type": "stream",
490
+ "text": [
491
+ " Mr. Quilter is the Apostle of the Middle Classes, and we are glad to welcome his Gospel.\n"
492
+ ]
493
+ }
494
+ ],
495
+ "execution_count": 46
496
+ },
497
+ {
498
+ "metadata": {
499
+ "ExecuteTime": {
500
+ "end_time": "2025-04-21T06:27:16.239153Z",
501
+ "start_time": "2025-04-21T06:27:15.587609Z"
502
+ }
503
+ },
504
+ "cell_type": "code",
505
+ "source": [
506
+ "save_dir = \"./models_for_proj/librispeech_asr_dummy\"\n",
507
+ "pipe.model.save_pretrained(save_dir)\n",
508
+ "pipe.tokenizer.save_pretrained(save_dir)\n",
509
+ "pipe.feature_extractor.save_pretrained(save_dir)"
510
+ ],
511
+ "id": "81b57090829a7294",
512
+ "outputs": [
513
+ {
514
+ "name": "stderr",
515
+ "output_type": "stream",
516
+ "text": [
517
+ "/Users/perchik/PycharmProjects/Learning_LLMs/.venv/lib/python3.12/site-packages/transformers/modeling_utils.py:3353: UserWarning: Moving the following attributes in the config to the generation config: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361]}. You are seeing this warning because you've set generation parameters in the model config, as opposed to in the generation config.\n",
518
+ " warnings.warn(\n"
519
+ ]
520
+ },
521
+ {
522
+ "data": {
523
+ "text/plain": [
524
+ "['./models_for_proj/librispeech_asr_dummy/preprocessor_config.json']"
525
+ ]
526
+ },
527
+ "execution_count": 45,
528
+ "metadata": {},
529
+ "output_type": "execute_result"
530
+ }
531
+ ],
532
+ "execution_count": 45
533
+ },
534
+ {
535
+ "metadata": {
536
+ "ExecuteTime": {
537
+ "end_time": "2025-04-21T05:31:45.237137Z",
538
+ "start_time": "2025-04-21T05:31:45.234474Z"
539
+ }
540
+ },
541
+ "cell_type": "code",
542
+ "source": "target_sr",
543
+ "id": "61b31c4b81fd098f",
544
+ "outputs": [
545
+ {
546
+ "data": {
547
+ "text/plain": [
548
+ "16000"
549
+ ]
550
+ },
551
+ "execution_count": 26,
552
+ "metadata": {},
553
+ "output_type": "execute_result"
554
+ }
555
+ ],
556
+ "execution_count": 26
557
+ },
558
+ {
559
+ "metadata": {
560
+ "ExecuteTime": {
561
+ "end_time": "2025-04-21T11:20:26.931270Z",
562
+ "start_time": "2025-04-21T11:20:24.762498Z"
563
+ }
564
+ },
565
+ "cell_type": "code",
566
+ "source": [
567
+ "# input\n",
568
+ "waveform, sample_rate = torchaudio.load(\"sample.wav\")\n",
569
+ "target_sr = 16000\n",
570
+ "resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sr, dtype=waveform.dtype)\n",
571
+ "waveform = resampler(waveform)\n",
572
+ "waveform_np = waveform.squeeze().numpy()\n",
573
+ "# sample = dataset[2][\"audio\"]\n",
574
+ "\n",
575
+ "# result = pipe(sample)\n",
576
+ "result = pipe(waveform_np)\n",
577
+ "print(result[\"text\"])"
578
+ ],
579
+ "id": "5c9f9ff839e346f8",
580
+ "outputs": [
581
+ {
582
+ "name": "stdout",
583
+ "output_type": "stream",
584
+ "text": [
585
+ " This is a simple text.\n"
586
+ ]
587
+ }
588
+ ],
589
+ "execution_count": 48
590
+ },
591
+ {
592
+ "metadata": {
593
+ "ExecuteTime": {
594
+ "end_time": "2025-04-21T11:54:49.800197Z",
595
+ "start_time": "2025-04-21T11:54:47.143900Z"
596
+ }
597
+ },
598
+ "cell_type": "code",
599
+ "source": [
600
+ "from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC\n",
601
+ "processor = Wav2Vec2Processor.from_pretrained(\"facebook/wav2vec2-base-960h\")\n",
602
+ "model = Wav2Vec2ForCTC.from_pretrained(\"facebook/wav2vec2-base-960h\")"
603
+ ],
604
+ "id": "a7084d040f38e0f5",
605
+ "outputs": [
606
+ {
607
+ "name": "stderr",
608
+ "output_type": "stream",
609
+ "text": [
610
+ "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']\n",
611
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
612
+ ]
613
+ }
614
+ ],
615
+ "execution_count": 49
616
+ },
617
+ {
618
+ "metadata": {},
619
+ "cell_type": "code",
620
+ "outputs": [],
621
+ "execution_count": null,
622
+ "source": "",
623
+ "id": "f886807e783c9532"
624
+ }
625
+ ],
626
+ "metadata": {
627
+ "kernelspec": {
628
+ "display_name": "Python 3",
629
+ "language": "python",
630
+ "name": "python3"
631
+ },
632
+ "language_info": {
633
+ "codemirror_mode": {
634
+ "name": "ipython",
635
+ "version": 2
636
+ },
637
+ "file_extension": ".py",
638
+ "mimetype": "text/x-python",
639
+ "name": "python",
640
+ "nbconvert_exporter": "python",
641
+ "pygments_lexer": "ipython2",
642
+ "version": "2.7.6"
643
+ }
644
+ },
645
+ "nbformat": 4,
646
+ "nbformat_minor": 5
647
+ }
draft_2.py DELETED
@@ -1,27 +0,0 @@
1
- import gymnasium as gym
2
-
3
- from stable_baselines3 import PPO
4
- from stable_baselines3.common.env_util import make_vec_env
5
- import torch
6
-
7
- # Parallel environments
8
- vec_env = make_vec_env("CartPole-v1", n_envs=4)
9
-
10
- policy_kwargs = dict(activation_fn=torch.nn.ReLU,
11
- net_arch=dict(pi=[32, 32], vf=[32, 32]))
12
- model = PPO("MlpPolicy", vec_env,
13
- verbose=1,
14
- policy_kwargs=policy_kwargs,
15
- tensorboard_log="./ppo_tensorboard/")
16
- model.learn(total_timesteps=100000, tb_log_name="CartPole")
17
- model.save("ppo_cartpole")
18
-
19
- del model # remove to demonstrate saving and loading
20
-
21
- model = PPO.load("ppo_cartpole")
22
-
23
- obs = vec_env.reset()
24
- while True:
25
- action, _states = model.predict(obs)
26
- obs, rewards, dones, info = vec_env.step(action)
27
- vec_env.render("human")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
draft_animation.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import matplotlib.animation as animation
4
+ import tempfile
5
+ from matplotlib.patches import Circle
6
+
7
+
8
def create_dummy_animation():
    """Render an animated sine curve and save it to a temporary MP4 file.

    Returns:
        str: Path of the temporary ``.mp4`` file. It is created with
        ``delete=False``, so the caller is responsible for removing it.
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    xdata, ydata = [], []
    # Draw on the axes we just created instead of the implicit current axes.
    ln, = ax.plot([], [], 'b-', animated=True)

    def init():
        # Fix the view limits once so blitting has a stable background.
        ax.set_xlim(0, 2 * np.pi)
        ax.set_ylim(-1.1, 1.1)
        return ln,

    def update(frame):
        # Accumulate points so the curve grows frame by frame.
        xdata.append(frame)
        ydata.append(np.sin(frame))
        ln.set_data(xdata, ydata)
        return ln,

    ani = animation.FuncAnimation(
        fig, update, frames=np.linspace(0, 2 * np.pi, 100),
        init_func=init, blit=True, repeat=False
    )

    # Save to MP4. Close the handle immediately: only the file *name* is
    # needed (ani.save reopens it by path), and the original kept the
    # descriptor open for the process lifetime — a resource leak.
    temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    temp_video.close()
    ani.save(temp_video.name, writer='ffmpeg', fps=20)
    plt.close(fig)

    return temp_video.name
35
+
36
+
37
def create_animation():
    """Animate a dummy agent moving along a diagonal path through a
    100x100 "warehouse" with four fixed targets, and save the result
    as a temporary MP4.

    Returns:
        str: Path of the temporary ``.mp4`` file (created with
        ``delete=False``; the caller must remove it when done).
    """
    path = [(i, i) for i in range(50)]
    targets_x = [20, 80, 80, 20]
    targets_y = [20, 20, 80, 80]
    RADIUS_COVERAGE = 10
    fig, ax = plt.subplots(figsize=(7, 7))

    # Agent marker: a single point updated every frame.
    ln1, = ax.plot([path[0][0]], [path[0][1]], marker='o', color='b',
                   alpha=0.5, linewidth=5, markersize=15)

    # Static targets plus a translucent circle showing each coverage radius.
    ax.plot(targets_x, targets_y, marker='X', color='orange', alpha=0.5,
            linestyle='none', markersize=15)
    for t_x, t_y in zip(targets_x, targets_y):
        circle = Circle((t_x, t_y), RADIUS_COVERAGE, color='orange',
                        fill=True, alpha=0.3)
        ax.add_patch(circle)

    def init():
        ax.set_xlim([0, 100])
        ax.set_ylim([0, 100])
        ax.set_title('Warehouse Env', fontweight="bold", size=10)
        return ln1,

    def update(frame):
        # Move the agent marker to this frame's path position.
        ln1.set_data([path[frame][0]], [path[frame][1]])
        return ln1,

    # Animate over the whole path. The original hard-coded frames=40 and
    # silently dropped the last 10 of the 50 path points.
    ani = animation.FuncAnimation(fig, update, frames=len(path),
                                  init_func=init, blit=True, repeat=False)

    # Save to MP4; close the handle right away — only the name is needed,
    # and keeping it open leaks the file descriptor.
    temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    temp_video.close()
    ani.save(temp_video.name, writer='ffmpeg', fps=20)
    plt.close(fig)
    return temp_video.name
76
+
77
+
78
def main():
    """Script entry point: render the warehouse animation to a temp MP4."""
    create_animation()


if __name__ == '__main__':
    main()
draft_gradio_update_example.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ # Function to validate the input and enable/disable the button
4
def validate_input(text):
    """Validate the textbox contents and gate the submit button.

    Returns a pair of ``gr.update`` payloads: the validation message for
    the label, and the button's new ``interactive`` state.
    """
    is_valid = len(text.strip()) >= 3
    error_msg = "" if is_valid else "_Input_ must be at least 3 characters."
    return gr.update(value=error_msg), gr.update(interactive=is_valid)
15
+
16
+ # Function that runs when the button is clicked
17
def on_submit(text):
    """Echo the stripped input back with a 'Processed:' prefix."""
    cleaned = text.strip()
    return "Processed: " + cleaned
19
+
20
# UI wiring: a textbox whose contents are validated live; the submit
# button stays disabled until the input passes validation.
with gr.Blocks() as demo:
    gr.Markdown("### Input Validation Example")

    inp = gr.Textbox(label="Enter something")
    # Label doubles as an inline validation-message area (starts empty).
    validation = gr.Label(value="", visible=True)
    # Button starts disabled; validate_input toggles it on valid input.
    btn = gr.Button("Submit", interactive=False)
    out = gr.Textbox(label="Output", interactive=False)

    # When the input changes, validate it and enable/disable the button
    inp.change(validate_input, inputs=inp, outputs=[validation, btn])

    # When the button is clicked, process the input
    btn.click(on_submit, inputs=inp, outputs=out)

demo.launch()
requirements.txt CHANGED
@@ -1 +1,11 @@
1
  gradio
 
 
 
 
 
 
 
 
 
 
 
1
  gradio
2
+
3
+ numpy
4
+ matplotlib
5
+ torch
6
+ torchaudio
7
+ transformers
8
+ stable_baselines3
9
+ gymnasium
10
+ vmas
11
+ datasets
sample.wav ADDED
Binary file (72.7 kB). View file
 
train_agent.py CHANGED
@@ -37,7 +37,7 @@ def train_func(alg_name='PPO'):
37
 
38
 
39
 
40
- def exec_func(alg_name='PPO', model_name=None):
41
  env = WarehouseEnv(render_mode='human')
42
  if alg_name == 'PPO':
43
  model_name = "ppo_warehouse" if model_name is None else model_name
@@ -50,7 +50,7 @@ def exec_func(alg_name='PPO', model_name=None):
50
  # vec_env = model.get_env()
51
  obs, info = env.reset()
52
  while True:
53
- action, _states = model.predict(obs)
54
  obs, rewards, done, trunc, info = env.step(action)
55
  env.render()
56
  if done or trunc:
@@ -60,7 +60,7 @@ def exec_func(alg_name='PPO', model_name=None):
60
  def main():
61
  # alg_name = 'PPO'
62
  alg_name = 'SAC'
63
- model_name = 'sac_warehouse_working_v1'
64
  # train_func(alg_name)
65
  exec_func(alg_name=alg_name, model_name=model_name)
66
 
 
37
 
38
 
39
 
40
+ def exec_func(alg_name='SAC', model_name=None):
41
  env = WarehouseEnv(render_mode='human')
42
  if alg_name == 'PPO':
43
  model_name = "ppo_warehouse" if model_name is None else model_name
 
50
  # vec_env = model.get_env()
51
  obs, info = env.reset()
52
  while True:
53
+ action, _ = model.predict(obs)
54
  obs, rewards, done, trunc, info = env.step(action)
55
  env.render()
56
  if done or trunc:
 
60
  def main():
61
  # alg_name = 'PPO'
62
  alg_name = 'SAC'
63
+ model_name = 'agent_policies/sac_warehouse_r_10_working_v1.zip'
64
  # train_func(alg_name)
65
  exec_func(alg_name=alg_name, model_name=model_name)
66
 
warehouse_env.py CHANGED
@@ -55,13 +55,21 @@ class WarehouseEnv(gym.Env):
55
  def rel_y(self) -> int:
56
  return self.agent_y - self.goal_y
57
 
58
- def reset(self, seed=None, options=None):
59
- self.agent_x = np.random.uniform(0, self.SIDE)
60
- self.agent_y = np.random.uniform(0, self.SIDE)
61
- # self.agent_x = 50.0
62
- # self.agent_y = 50.0
63
- self.goal_x = np.random.uniform(0, self.SIDE)
64
- self.goal_y = np.random.uniform(0, self.SIDE)
 
 
 
 
 
 
 
 
65
  self.step_counter = 0
66
  self.terminated = False
67
  self.truncated = False
 
55
  def rel_y(self) -> int:
56
  return self.agent_y - self.goal_y
57
 
58
+ def reset(self, seed=None, options=None, agent_x=None, agent_y=None, goal_x=None, goal_y=None):
59
+ if agent_x is None:
60
+ self.agent_x = np.random.uniform(0, self.SIDE)
61
+ self.agent_y = np.random.uniform(0, self.SIDE)
62
+ # self.agent_x = 50.0
63
+ # self.agent_y = 50.0
64
+ else:
65
+ self.agent_x = agent_x
66
+ self.agent_y = agent_y
67
+ if goal_x is None:
68
+ self.goal_x = np.random.uniform(0, self.SIDE)
69
+ self.goal_y = np.random.uniform(0, self.SIDE)
70
+ else:
71
+ self.goal_x = goal_x
72
+ self.goal_y = goal_y
73
  self.step_counter = 0
74
  self.terminated = False
75
  self.truncated = False