---
license: apache-2.0
base_model:
- Qwen/Qwen3-Embedding-8B
pipeline_tag: sentence-similarity
---

This is the SitEmb-v1.5-Qwen3 model trained with additional book notes and their corresponding underlined texts.


### Transformer Usage
```python
import torch

from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from more_itertools import chunked


# Whether to also compute the chunk-only "residual" embedding, and how much
# weight it gets when mixing it with the situated-chunk score at the end.
residual = True
residual_factor = 0.5

tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-Embedding-8B",
    use_fast=True,
    padding_side='left',
)

model = AutoModel.from_pretrained(
    "SituatedEmbedding/SitEmb-v1.5-Qwen3-note",
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)

def _pooling(last_hidden_state, attention_mask, pooling, normalize, input_ids=None, match_idx=None):
    if pooling in ['cls', 'first']:
        reps = last_hidden_state[:, 0]
    elif pooling in ['mean', 'avg', 'average']:
        masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
        reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    elif pooling in ['last', 'eos']:
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            reps = last_hidden_state[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
    elif pooling == 'ext':
        if match_idx is None:
            # default mean
            masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
            reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
        else:
            for k in range(input_ids.shape[0]):
                sep_index = input_ids[k].tolist().index(match_idx)
                attention_mask[k][sep_index:] = 0
            masked_hiddens = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
            reps = masked_hiddens.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    else:
        raise ValueError(f'unknown pooling method: {pooling}')
    if normalize:
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps


def first_eos_token_pooling(
    last_hidden_states,
    first_eos_position,
    normalize,
):
    # Pool at the first EOS position, i.e. at the end of the chunk itself,
    # before the appended context.
    batch_size = last_hidden_states.shape[0]
    reps = last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), first_eos_position]
    if normalize:
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
    return reps

def get_detailed_instruct(task_description, query):
    # Standard Qwen3-Embedding instruction template (assumed; the helper is
    # referenced but not defined in the original snippet).
    return f'Instruct: {task_description}\nQuery: {query}'


def encode_query(tokenizer, model, pooling, queries, batch_size, normalize, max_length, residual):
    task = "Given a search query, retrieve relevant chunks from fictions that answer the query"
    sents = []
    for query in queries:
        sents.append(get_detailed_instruct(task, query))

    # Queries carry no context, so no residual embedding is computed for them.
    return encode_passage(tokenizer, model, pooling, sents, batch_size, normalize, max_length)


def encode_passage(tokenizer, model, pooling, passages, batch_size, normalize, max_length, residual=False):
    pas_embs = []
    pas_embs_residual = []
    total = len(passages) // batch_size + (1 if len(passages) % batch_size != 0 else 0)
    with tqdm(total=total) as pbar:
        for sent_b in chunked(passages, batch_size):
            batch_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True,
                                   return_tensors='pt').to(model.device)
            if residual:
                batch_list_dict = tokenizer(sent_b, max_length=max_length, padding=True, truncation=True)
                input_ids = batch_list_dict['input_ids']
                attention_mask = batch_list_dict['attention_mask']
                max_len = len(input_ids[0])
                # With left padding, the real tokens of each sequence start here.
                input_starts = [max_len - sum(att) for att in attention_mask]
                eos_pos = []
                for ii, it in zip(input_ids, input_starts):
                    # First pad/EOS token after the padding, i.e. the <|endoftext|>
                    # separator between the chunk and its context.
                    pos = ii.index(tokenizer.pad_token_id, it)
                    eos_pos.append(pos)
                eos_pos = torch.tensor(eos_pos).to(model.device)
            else:
                eos_pos = None
            outputs = model(**batch_dict)
            pemb_ = _pooling(outputs.last_hidden_state, batch_dict['attention_mask'], pooling, normalize)
            if residual:
                remb_ = first_eos_token_pooling(outputs.last_hidden_state, eos_pos, normalize)
                pas_embs_residual.append(remb_)
            pas_embs.append(pemb_)
            pbar.update(1)
    pas_embs = torch.cat(pas_embs, dim=0)
    if pas_embs_residual:
        pas_embs_residual = torch.cat(pas_embs_residual, dim=0)
    else:
        pas_embs_residual = None
    return pas_embs, pas_embs_residual

your_query = "Your Query"

query_hidden, _ = encode_query(
    tokenizer, model, pooling="eos", queries=[your_query],
    batch_size=8, normalize=True, max_length=8192, residual=residual,
)

passage_affix = "The context in which the chunk is situated is given below. Please encode the chunk by being aware of the context. Context:\n"
your_chunk = "Your Chunk"
your_context = "Your Context"

# Passage format: chunk, then <|endoftext|> as a separator, then the context prompt.
candidate_hidden, candidate_hidden_residual = encode_passage(
    tokenizer, model, pooling="eos", passages=[f"{your_chunk}<|endoftext|>{passage_affix}{your_context}"],
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)

query2candidate = query_hidden @ candidate_hidden.T  # [num_queries, num_candidates]
if candidate_hidden_residual is not None:
    query2candidate_residual = query_hidden @ candidate_hidden_residual.T
    if residual_factor == 1.:
        query2candidate = query2candidate_residual
    elif residual_factor == 0.:
        pass
    else:
        query2candidate = query2candidate * (1. - residual_factor) + query2candidate_residual * residual_factor

print(query2candidate.tolist())
```
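To rank several situated chunks against one query, the helpers above can be reused directly. The following is a minimal sketch, assuming the setup and functions from the snippet above are already in scope; the chunk, context, and query strings are placeholders.

```python
# A minimal retrieval sketch, assuming tokenizer, model, passage_affix,
# residual, residual_factor, encode_query, and encode_passage from the
# snippet above are already defined. Chunks, contexts, and the query are
# placeholders.
chunks = ["Chunk 1 text", "Chunk 2 text"]
contexts = ["Context for chunk 1", "Context for chunk 2"]

passages = [
    f"{chunk}<|endoftext|>{passage_affix}{context}"
    for chunk, context in zip(chunks, contexts)
]

cand_hidden, cand_hidden_residual = encode_passage(
    tokenizer, model, pooling="eos", passages=passages,
    batch_size=4, normalize=True, max_length=8192, residual=residual,
)

query_hidden, _ = encode_query(
    tokenizer, model, pooling="eos", queries=["Your Query"],
    batch_size=8, normalize=True, max_length=8192, residual=residual,
)

# Combine the situated-chunk score with the chunk-only residual score.
scores = query_hidden @ cand_hidden.T  # [1, num_chunks]
if cand_hidden_residual is not None and residual_factor > 0.:
    scores = scores * (1. - residual_factor) + (query_hidden @ cand_hidden_residual.T) * residual_factor

# Rank candidate chunks for the single query by combined score.
ranking = torch.argsort(scores[0], descending=True).tolist()
print(ranking)
```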