root commited on
Commit
4d6d57e
·
1 Parent(s): eb8bfb7

compatible with L40

Browse files
Files changed (2) hide show
  1. Dockerfile +5 -0
  2. vllm_hacked/v1/sample/sampler.py +3 -4
Dockerfile CHANGED
@@ -7,6 +7,11 @@ RUN apt-get update && \
7
  git lfs install && \
8
  rm -rf /var/lib/apt/lists/*
9
 
 
 
 
 
 
10
  RUN useradd -m -u 1000 user
11
  USER user
12
  ENV PATH="/home/user/.local/bin:$PATH"
 
7
  git lfs install && \
8
  rm -rf /var/lib/apt/lists/*
9
 
10
+ COPY ./vllm_hacked/model_executor/models/llama.py /opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/llama.py
11
+ COPY ./vllm_hacked/v1/sample/sampler.py /opt/conda/lib/python3.11/site-packages/vllm/v1/sample/sampler.py
12
+ COPY ./vllm_hacked/v1/sample/metadata.py /opt/conda/lib/python3.11/site-packages/vllm/v1/sample/metadata.py
13
+ COPY ./vllm_hacked/sampling_params.py /opt/conda/lib/python3.11/site-packages/vllm/sampling_params.py
14
+
15
  RUN useradd -m -u 1000 user
16
  USER user
17
  ENV PATH="/home/user/.local/bin:$PATH"
vllm_hacked/v1/sample/sampler.py CHANGED
@@ -187,10 +187,9 @@ class Sampler(nn.Module):
187
  # Avoid division by zero if there are greedy requests.
188
  if not all_random:
189
  temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
190
- try:
191
- return logits.div_(temp.view(-1, 1))
192
- except:
193
- return logits.div_(temp.unsqueeze(dim=1))
194
 
195
  def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
196
  return logits.argmax(dim=-1).view(-1)
 
187
  # Avoid division by zero if there are greedy requests.
188
  if not all_random:
189
  temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
190
+ if temp.dim() < logits.dim():
191
+ temp = temp.view([-1] + [1] * (logits.dim() - 1))
192
+ return logits / temp
 
193
 
194
  def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
195
  return logits.argmax(dim=-1).view(-1)