Dataset: `sentence-transformers/stsb`
How to use omkar334/reranker-distilroberta-base-stsb with sentence-transformers:
```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("omkar334/reranker-distilroberta-base-stsb")

query = "Which planet is known as the Red Planet?"
passages = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

scores = model.predict([(query, passage) for passage in passages])
print(scores)
```

This is a Cross Encoder model finetuned from distilbert/distilroberta-base on the stsb dataset using the sentence-transformers library. It computes scores for pairs of texts, which can be used for text reranking and semantic search.
Full model architecture:

```
CrossEncoder(
  (0): Transformer({'transformer_task': 'sequence-classification', 'modality_config': {'text': {'method': 'forward', 'method_output_name': 'logits'}}, 'module_output_name': 'scores', 'architecture': 'RobertaForSequenceClassification'})
)
```
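The printed architecture ends in a single `RobertaForSequenceClassification` head whose `logits` are returned as `scores`, one value per text pair. Assuming the usual default for a one-label cross encoder, where `predict` passes each logit through a sigmoid, every score lands in (0, 1), which matches the values printed in the usage examples. A minimal sketch of that mapping and its inverse; the helper names and the logits are hypothetical:

```python
import math

# Hypothetical helpers, not part of sentence-transformers.
def sigmoid(logit: float) -> float:
    """Map a raw classification logit to a score in (0, 1)."""
    # Numerically stable form: never calls exp() on a large positive number.
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)

def logit(score: float) -> float:
    """Inverse of sigmoid: recover the raw logit from a score in (0, 1)."""
    return math.log(score / (1.0 - score))

# Made-up raw logits for three text pairs.
raw = [3.0, 0.0, -0.5]
scores = [sigmoid(x) for x in raw]
print([round(s, 4) for s in scores])  # [0.9526, 0.5, 0.3775]
```

Because the sigmoid is monotonic, ranking candidates by raw logits or by sigmoid scores gives the same order.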
First install the Sentence Transformers library:
```bash
pip install -U sentence-transformers
```
Then you can load this model and run inference.
```python
from sentence_transformers import CrossEncoder

# Download from the 🤗 Hub
model = CrossEncoder("omkar334/reranker-distilroberta-base-stsb")
# Get scores for pairs of inputs
pairs = [
    ['A man with a hard hat is dancing.', 'A man wearing a hard hat is dancing.'],
    ['A young child is riding a horse.', 'A child is riding a horse.'],
    ['A man is feeding a mouse to a snake.', 'The man is feeding a mouse to the snake.'],
    ['A woman is playing the guitar.', 'A man is playing guitar.'],
    ['A woman is playing the flute.', 'A man is playing a flute.'],
]
scores = model.predict(pairs)
print(scores)
# [0.9598 0.9533 0.9566 0.3766 0.4535]

# Or rank different texts based on similarity to a single text
ranks = model.rank(
    'A man with a hard hat is dancing.',
    [
        'A man wearing a hard hat is dancing.',
        'A child is riding a horse.',
        'The man is feeding a mouse to the snake.',
        'A man is playing guitar.',
        'A man is playing a flute.',
    ]
)
# [{'corpus_id': ..., 'score': ...}, {'corpus_id': ..., 'score': ...}, ...]
```
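`rank` scores the single query text against every candidate and returns the candidates ordered by descending score, each entry carrying its position in the input list as `corpus_id`. Assuming that behaviour, the same ordering can be sketched as post-processing of a plain `predict` output; `rank_by_score` and the score values are hypothetical:

```python
from typing import Dict, List, Optional, Sequence, Union

# Hypothetical helper that mimics the shape of rank()'s output.
def rank_by_score(
    scores: Sequence[float], top_k: Optional[int] = None
) -> List[Dict[str, Union[int, float]]]:
    """Order candidate indices by descending score."""
    ranked = [{"corpus_id": i, "score": float(s)} for i, s in enumerate(scores)]
    # sorted() is stable even with reverse=True, so ties keep their input order.
    ranked = sorted(ranked, key=lambda r: r["score"], reverse=True)
    return ranked if top_k is None else ranked[:top_k]

# Made-up scores for five candidates against one query, e.g. the output of
# model.predict([(query, doc) for doc in documents]).
scores = [0.96, 0.12, 0.05, 0.31, 0.28]
print(rank_by_score(scores, top_k=3))
# [{'corpus_id': 0, 'score': 0.96}, {'corpus_id': 3, 'score': 0.31}, {'corpus_id': 4, 'score': 0.28}]
```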
Evaluation on the `stsb-validation` and `stsb-test` splits, computed with `CrossEncoderCorrelationEvaluator`:

| Metric | stsb-validation | stsb-test |
|---|---|---|
| pearson | 0.8864 | 0.8504 |
| spearman | 0.8838 | 0.8404 |
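`CrossEncoderCorrelationEvaluator` reports how closely the predicted scores track the gold `score` labels: Pearson measures linear agreement, while Spearman measures agreement of the rankings (Pearson computed on ranks, with tied values sharing their average rank). A dependency-free sketch of both statistics follows; the helper names and toy numbers are hypothetical, and the library itself may compute these with SciPy instead:

```python
import math

# Hypothetical helpers illustrating the two correlation metrics.
def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    return cov / math.sqrt(vx * vy)

def average_ranks(values):
    """1-based ranks; tied values share the mean of the ranks they occupy."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values equal to values[order[i]].
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks

def spearman(xs, ys):
    """Spearman rank correlation: Pearson correlation of the average ranks."""
    return pearson(average_ranks(xs), average_ranks(ys))

# Made-up gold labels and predicted scores for five pairs.
gold = [0.9, 0.1, 0.5, 0.5, 0.3]
pred = [0.8, 0.2, 0.55, 0.6, 0.25]
print(round(pearson(gold, pred), 4), round(spearman(gold, pred), 4))
```

Spearman only rewards getting the order right, which is what matters for reranking; Pearson additionally penalises scores that are ordered correctly but not linearly related to the labels.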
Training dataset, with columns `sentence1`, `sentence2`, and `score`:

| | sentence1 | sentence2 | score |
|---|---|---|---|
| type | string | string | float |
| modality | text | text | |

Samples:

| sentence1 | sentence2 | score |
|---|---|---|
| A plane is taking off. | An air plane is taking off. | 1.0 |
| A man is playing a large flute. | A man is playing a flute. | 0.76 |
| A man is spreading shreded cheese on a pizza. | A man is spreading shredded cheese on an uncooked pizza. | 0.76 |
Loss: `BinaryCrossEntropyLoss` with these parameters:

```json
{
    "activation_fn": "torch.nn.modules.linear.Identity",
    "pos_weight": null
}
```
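With `activation_fn` set to `Identity`, the loss is computed on the raw logit, and with `pos_weight` set to `null` every pair counts equally. Assuming this amounts to standard binary cross-entropy with logits against the float `score` label, a pure-Python sketch looks like this (the function names are illustrative, not the library's API). Because the labels are soft values in [0, 1], the loss for a pair cannot fall below the binary entropy of its label; that minimum is reached when the sigmoid of the logit equals the label.

```python
import math

# Illustrative helpers, not the sentence-transformers API.
def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy between sigmoid(logit) and a soft target in [0, 1].

    Uses the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)).
    """
    x, y = logit, target
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))

def binary_entropy(p: float) -> float:
    """Entropy (in nats) of a Bernoulli(p) label: the lowest loss any logit can reach."""
    if p in (0.0, 1.0):
        return 0.0
    return -(p * math.log(p) + (1.0 - p) * math.log(1.0 - p))

def mean_bce(logits, targets) -> float:
    """Unweighted mean over a batch, i.e. no pos_weight."""
    return sum(bce_with_logits(x, y) for x, y in zip(logits, targets)) / len(targets)

# A soft label from the training samples; its loss-minimising logit is log(y / (1 - y)).
y = 0.76
best = math.log(y / (1.0 - y))
print(round(bce_with_logits(best, y), 4), round(binary_entropy(y), 4))  # 0.5511 0.5511
```

This is one reason the training and validation losses in the training log level off well above zero.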
Evaluation dataset, with columns `sentence1`, `sentence2`, and `score`:

| | sentence1 | sentence2 | score |
|---|---|---|---|
| type | string | string | float |
| modality | text | text | |

Samples:

| sentence1 | sentence2 | score |
|---|---|---|
| A man with a hard hat is dancing. | A man wearing a hard hat is dancing. | 1.0 |
| A young child is riding a horse. | A child is riding a horse. | 0.95 |
| A man is feeding a mouse to a snake. | The man is feeding a mouse to the snake. | 1.0 |
Loss: `BinaryCrossEntropyLoss` with these parameters:

```json
{
    "activation_fn": "torch.nn.modules.linear.Identity",
    "pos_weight": null
}
```
Non-default hyperparameters:

- `per_device_train_batch_size`: 64
- `num_train_epochs`: 4
- `warmup_steps`: 0.1
- `bf16`: True
- `per_device_eval_batch_size`: 64

All hyperparameters:

- `per_device_train_batch_size`: 64
- `num_train_epochs`: 4
- `max_steps`: -1
- `learning_rate`: 5e-05
- `lr_scheduler_type`: linear
- `lr_scheduler_kwargs`: None
- `warmup_steps`: 0.1
- `optim`: adamw_torch_fused
- `optim_args`: None
- `weight_decay`: 0.0
- `adam_beta1`: 0.9
- `adam_beta2`: 0.999
- `adam_epsilon`: 1e-08
- `optim_target_modules`: None
- `gradient_accumulation_steps`: 1
- `average_tokens_across_devices`: True
- `max_grad_norm`: 1.0
- `label_smoothing_factor`: 0.0
- `bf16`: True
- `fp16`: False
- `bf16_full_eval`: False
- `fp16_full_eval`: False
- `tf32`: None
- `gradient_checkpointing`: False
- `gradient_checkpointing_kwargs`: None
- `torch_compile`: False
- `torch_compile_backend`: None
- `torch_compile_mode`: None
- `use_liger_kernel`: False
- `liger_kernel_config`: None
- `use_cache`: False
- `neftune_noise_alpha`: None
- `torch_empty_cache_steps`: None
- `auto_find_batch_size`: False
- `log_on_each_node`: True
- `logging_nan_inf_filter`: True
- `include_num_input_tokens_seen`: no
- `log_level`: passive
- `log_level_replica`: warning
- `disable_tqdm`: False
- `project`: huggingface
- `trackio_space_id`: None
- `trackio_bucket_id`: None
- `trackio_static_space_id`: None
- `per_device_eval_batch_size`: 64
- `prediction_loss_only`: True
- `eval_on_start`: False
- `eval_do_concat_batches`: True
- `eval_use_gather_object`: False
- `eval_accumulation_steps`: None
- `include_for_metrics`: []
- `batch_eval_metrics`: False
- `save_only_model`: False
- `save_on_each_node`: False
- `enable_jit_checkpoint`: False
- `push_to_hub`: False
- `hub_private_repo`: None
- `hub_model_id`: None
- `hub_strategy`: every_save
- `hub_always_push`: False
- `hub_revision`: None
- `load_best_model_at_end`: False
- `ignore_data_skip`: False
- `restore_callback_states_from_checkpoint`: False
- `full_determinism`: False
- `seed`: 42
- `data_seed`: None
- `use_cpu`: False
- `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
- `parallelism_config`: None
- `dataloader_drop_last`: False
- `dataloader_num_workers`: 0
- `dataloader_pin_memory`: True
- `dataloader_persistent_workers`: False
- `dataloader_prefetch_factor`: None
- `remove_unused_columns`: True
- `label_names`: None
- `train_sampling_strategy`: random
- `length_column_name`: length
- `ddp_find_unused_parameters`: None
- `ddp_bucket_cap_mb`: None
- `ddp_broadcast_buffers`: False
- `ddp_static_graph`: None
- `ddp_backend`: None
- `ddp_timeout`: 1800
- `fsdp`: []
- `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
- `deepspeed`: None
- `debug`: []
- `skip_memory_metrics`: True
- `do_predict`: False
- `resume_from_checkpoint`: None
- `warmup_ratio`: None
- `local_rank`: -1
- `prompts`: None
- `batch_sampler`: batch_sampler
- `multi_dataset_batch_sampler`: proportional
- `router_mapping`: {}
- `learning_rate_mapping`: {}

Training logs:

| Epoch | Step | Training Loss | Validation Loss | stsb-validation_spearman | stsb-test_spearman |
|---|---|---|---|---|---|
| -1 | -1 | - | - | -0.0362 | - |
| 0.2222 | 20 | 0.6909 | - | - | - |
| 0.4444 | 40 | 0.6506 | - | - | - |
| 0.6667 | 60 | 0.5969 | - | - | - |
| 0.8889 | 80 | 0.5680 | 0.5461 | 0.8552 | - |
| 1.1111 | 100 | 0.5551 | - | - | - |
| 1.3333 | 120 | 0.5379 | - | - | - |
| 1.5556 | 140 | 0.5449 | - | - | - |
| 1.7778 | 160 | 0.5443 | 0.5342 | 0.8777 | - |
| 2.0 | 180 | 0.5373 | - | - | - |
| 2.2222 | 200 | 0.5287 | - | - | - |
| 2.4444 | 220 | 0.5248 | - | - | - |
| 2.6667 | 240 | 0.5283 | 0.5383 | 0.8785 | - |
| 2.8889 | 260 | 0.5251 | - | - | - |
| 3.1111 | 280 | 0.5156 | - | - | - |
| 3.3333 | 300 | 0.5093 | - | - | - |
| 3.5556 | 320 | 0.5164 | 0.5369 | 0.8824 | - |
| 3.7778 | 340 | 0.5152 | - | - | - |
| 4.0 | 360 | 0.5208 | 0.5331 | 0.8838 | - |
| -1 | -1 | - | - | - | 0.8404 |
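The log implies 90 optimizer steps per epoch (step 20 is logged at epoch 0.2222), so 360 steps over the 4 epochs. Assuming the fractional `warmup_steps: 0.1` is read as a ratio of total training steps, rounded up to 36 warmup steps, and that `lr_scheduler_type: linear` means linear warmup followed by linear decay to zero, the learning rate at each step can be sketched as below. The helper is hypothetical, not the Transformers implementation.

```python
import math

# Hypothetical helper; the fractional-warmup interpretation is an assumption.
def linear_warmup_decay_lr(step: int, total_steps: int, warmup: float, peak_lr: float) -> float:
    """Learning rate after `step` optimizer steps.

    Rises linearly from 0 to peak_lr during warmup, then decays linearly to 0
    at total_steps. A `warmup` below 1 is treated as a fraction of total_steps,
    rounded up; otherwise it is an absolute step count.
    """
    warmup_steps = math.ceil(warmup * total_steps) if warmup < 1 else int(warmup)
    if step < warmup_steps:
        return peak_lr * (step / warmup_steps)
    remaining = max(0, total_steps - step)
    return peak_lr * (remaining / max(1, total_steps - warmup_steps))

# Derived from the training log: step 20 is logged at epoch 0.2222.
steps_per_epoch = round(20 / 0.2222)
total_steps = steps_per_epoch * 4  # num_train_epochs: 4

for step in (0, 18, 36, 198, 360):
    print(step, linear_warmup_decay_lr(step, total_steps, warmup=0.1, peak_lr=5e-05))
# Step 36 reaches the peak of 5e-05 and step 360 returns 0.0.
```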
Citation:

```bibtex
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}
```