How to use from the sentence-transformers library
from sentence_transformers import CrossEncoder

model = CrossEncoder("omkar334/reranker-distilroberta-base-stsb")

query = "Which planet is known as the Red Planet?"
passages = [
	"Venus is often called Earth's twin because of its similar size and proximity.",
	"Mars, known for its reddish appearance, is often referred to as the Red Planet.",
	"Jupiter, the largest planet in our solar system, has a prominent red spot.",
	"Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

scores = model.predict([(query, passage) for passage in passages])
print(scores)

CrossEncoder based on distilbert/distilroberta-base

This is a Cross Encoder model finetuned from distilbert/distilroberta-base on the stsb dataset using the sentence-transformers library. It computes scores for pairs of texts, which can be used for text reranking and semantic search.
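
A minimal sketch of how pair scores can drive reranking. The helper `rerank` and the stand-in scorer `overlap_score` are hypothetical names, not part of sentence-transformers. The stand-in only measures word overlap so that the example runs without downloading the model. In practice one could pass `model.predict` as `score_fn`, since it accepts a list of (query, passage) pairs.

```python
from typing import Callable, List, Optional, Sequence, Tuple


def rerank(
    query: str,
    passages: Sequence[str],
    score_fn: Callable[[List[Tuple[str, str]]], Sequence[float]],
    top_k: Optional[int] = None,
) -> List[Tuple[int, float]]:
    """Return (passage_index, score) pairs sorted by descending score."""
    pairs = [(query, passage) for passage in passages]
    scores = score_fn(pairs)
    ranked = sorted(enumerate(scores), key=lambda item: item[1], reverse=True)
    ranked = [(idx, float(score)) for idx, score in ranked]
    return ranked if top_k is None else ranked[:top_k]


def overlap_score(pairs):
    """Hypothetical stand-in scorer: Jaccard overlap of lowercase word sets."""
    scores = []
    for a, b in pairs:
        words_a, words_b = set(a.lower().split()), set(b.lower().split())
        scores.append(len(words_a & words_b) / len(words_a | words_b))
    return scores


# Toy inputs chosen for illustration only
query = "red planet"
passages = ["the blue ocean", "the red planet mars", "a red apple"]
ranked = rerank(query, passages, overlap_score)
for idx, score in ranked:
    print(f"{score:.2f}  {passages[idx]}")
```

Keeping the scorer as a parameter means the same sorting logic works with any callable that maps pairs to scores, including the cross encoder itself.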

Model Details

Model Description

  • Model Type: Cross Encoder
  • Base model: distilbert/distilroberta-base
  • Maximum Sequence Length: 512 tokens
  • Number of Output Labels: 1 label
  • Supported Modality: Text
  • Training Dataset: stsb
  • Language: en

Full Model Architecture

CrossEncoder(
  (0): Transformer({'transformer_task': 'sequence-classification', 'modality_config': {'text': {'method': 'forward', 'method_output_name': 'logits'}}, 'module_output_name': 'scores', 'architecture': 'RobertaForSequenceClassification'})
)

Usage

Direct Usage (Sentence Transformers)

First install the Sentence Transformers library:

pip install -U sentence-transformers

Then you can load this model and run inference.

from sentence_transformers import CrossEncoder

# Download from the 🤗 Hub
model = CrossEncoder("omkar334/reranker-distilroberta-base-stsb")
# Get scores for pairs of inputs
pairs = [
    ['A man with a hard hat is dancing.', 'A man wearing a hard hat is dancing.'],
    ['A young child is riding a horse.', 'A child is riding a horse.'],
    ['A man is feeding a mouse to a snake.', 'The man is feeding a mouse to the snake.'],
    ['A woman is playing the guitar.', 'A man is playing guitar.'],
    ['A woman is playing the flute.', 'A man is playing a flute.'],
]
scores = model.predict(pairs)
print(scores)
# [0.9598 0.9533 0.9566 0.3766 0.4535]

# Or rank different texts based on similarity to a single text
ranks = model.rank(
    'A man with a hard hat is dancing.',
    [
        'A man wearing a hard hat is dancing.',
        'A child is riding a horse.',
        'The man is feeding a mouse to the snake.',
        'A man is playing guitar.',
        'A man is playing a flute.',
    ]
)
# [{'corpus_id': ..., 'score': ...}, {'corpus_id': ..., 'score': ...}, ...]

Evaluation

Metrics

Cross Encoder Correlation

Metric    stsb-validation  stsb-test
pearson   0.8864           0.8504
spearman  0.8838           0.8404
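
The two metrics compare predicted scores with gold similarity labels: Pearson measures linear agreement, while Spearman applies the same formula to ranks and so only rewards getting the ordering right. The sketch below illustrates this in plain Python. The function names are hypothetical and the inputs are toy values, not data from this evaluation. The evaluator used to produce the table presumably relies on a library implementation instead.

```python
import math


def pearson(x, y):
    """Pearson correlation of two equal-length sequences."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    std_x = math.sqrt(sum((a - mean_x) ** 2 for a in x))
    std_y = math.sqrt(sum((b - mean_y) ** 2 for b in y))
    return cov / (std_x * std_y)


def average_ranks(values):
    """Rank values from 1..n, giving tied values the mean of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with position i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def spearman(x, y):
    """Spearman correlation: Pearson correlation of the average ranks."""
    return pearson(average_ranks(x), average_ranks(y))


# Toy values: predictions are a monotone but non-linear function of the labels
gold = [0.0, 0.2, 0.4, 0.6, 1.0]
pred = [g ** 3 for g in gold]
print("pearson ", pearson(gold, pred))
print("spearman", spearman(gold, pred))
```

Because the toy predictions preserve the ordering of the labels exactly, Spearman is at its maximum while Pearson is lower, which is why the two columns in the table can differ.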

Training Details

Training Dataset

stsb

  • Dataset: stsb at ab7a5ac
  • Size: 5,749 training samples
  • Columns: sentence1, sentence2, and score
  • Approximate statistics based on the first 100 samples:
                sentence1          sentence2          score
    type        string             string             float
    modality    text               text
    details     min: 7 tokens      min: 7 tokens      min: 0.1
                mean: 9.49 tokens  mean: 9.61 tokens  mean: 0.66
                max: 14 tokens     max: 17 tokens     max: 1.0
  • Samples:
    sentence1 | sentence2 | score
    A plane is taking off. | An air plane is taking off. | 1.0
    A man is playing a large flute. | A man is playing a flute. | 0.76
    A man is spreading shreded cheese on a pizza. | A man is spreading shredded cheese on an uncooked pizza. | 0.76
  • Loss: BinaryCrossEntropyLoss with these parameters:
    {
        "activation_fn": "torch.nn.modules.linear.Identity",
        "pos_weight": null
    }
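
The identity activation suggests that the loss is computed directly on the model's raw logit, with the STS-B score in [0, 1] acting as a soft target. A minimal plain-Python sketch of that computation for a single pair follows. The function names are hypothetical, and the assumption is that the library loss behaves like a standard binary cross-entropy on logits with mean reduction. Library implementations typically use a numerically stabler formulation.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bce_with_logits(logit, target):
    """Binary cross-entropy for one raw logit and a soft target in [0, 1]."""
    p = sigmoid(logit)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def logit_of(p):
    """Inverse of sigmoid for 0 < p < 1."""
    return math.log(p / (1.0 - p))


target = 0.76  # score of the second training sample listed above
best = bce_with_logits(logit_of(target), target)   # prediction equals the target
worse = bce_with_logits(logit_of(0.95), target)    # over-confident prediction
# Binary entropy of the target: the lowest loss reachable for this label
entropy = -(target * math.log(target) + (1 - target) * math.log(1 - target))
print(best, worse, entropy)
```

With a soft target the loss is minimized when the sigmoid of the logit equals the target, and the minimum is the binary entropy of the target rather than zero. This is one plausible reason why the logged losses level off near 0.5 instead of approaching 0.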
    

Evaluation Dataset

stsb

  • Dataset: stsb at ab7a5ac
  • Size: 1,500 evaluation samples
  • Columns: sentence1, sentence2, and score
  • Approximate statistics based on the first 100 samples:
                sentence1           sentence2          score
    type        string              string             float
    modality    text                text
    details     min: 7 tokens       min: 7 tokens      min: 0.0
                mean: 10.04 tokens  mean: 9.98 tokens  mean: 0.53
                max: 19 tokens      max: 18 tokens     max: 1.0
  • Samples:
    sentence1 | sentence2 | score
    A man with a hard hat is dancing. | A man wearing a hard hat is dancing. | 1.0
    A young child is riding a horse. | A child is riding a horse. | 0.95
    A man is feeding a mouse to a snake. | The man is feeding a mouse to the snake. | 1.0
  • Loss: BinaryCrossEntropyLoss with these parameters:
    {
        "activation_fn": "torch.nn.modules.linear.Identity",
        "pos_weight": null
    }
    

Training Hyperparameters

Non-Default Hyperparameters

  • per_device_train_batch_size: 64
  • num_train_epochs: 4
  • warmup_steps: 0.1
  • bf16: True
  • per_device_eval_batch_size: 64
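
These settings can be related to the step counts in the training logs with a little arithmetic. The sketch below assumes single-device training without gradient accumulation and with the final partial batch kept. It also assumes that a `warmup_steps` value below 1 is read as a fraction of the total steps; that interpretation is an assumption and should be checked against the installed Transformers version.

```python
import math

train_samples = 5749   # training set size reported in this card
batch_size = 64        # per_device_train_batch_size
epochs = 4             # num_train_epochs
warmup = 0.1           # warmup_steps as listed

# The last batch of each epoch is partial, so round up
steps_per_epoch = math.ceil(train_samples / batch_size)
total_steps = steps_per_epoch * epochs

# ASSUMPTION: a value below 1 means "fraction of total steps"
warmup_steps = round(total_steps * warmup) if warmup < 1 else int(warmup)

# Fractional epoch reached after 20 optimizer steps
epoch_at_step_20 = round(20 / steps_per_epoch, 4)

print(steps_per_epoch, total_steps, warmup_steps, epoch_at_step_20)
```

Under these assumptions the totals agree with the training logs, which end at step 360 for epoch 4.0 and report epoch 0.2222 at step 20.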

All Hyperparameters

  • per_device_train_batch_size: 64
  • num_train_epochs: 4
  • max_steps: -1
  • learning_rate: 5e-05
  • lr_scheduler_type: linear
  • lr_scheduler_kwargs: None
  • warmup_steps: 0.1
  • optim: adamw_torch_fused
  • optim_args: None
  • weight_decay: 0.0
  • adam_beta1: 0.9
  • adam_beta2: 0.999
  • adam_epsilon: 1e-08
  • optim_target_modules: None
  • gradient_accumulation_steps: 1
  • average_tokens_across_devices: True
  • max_grad_norm: 1.0
  • label_smoothing_factor: 0.0
  • bf16: True
  • fp16: False
  • bf16_full_eval: False
  • fp16_full_eval: False
  • tf32: None
  • gradient_checkpointing: False
  • gradient_checkpointing_kwargs: None
  • torch_compile: False
  • torch_compile_backend: None
  • torch_compile_mode: None
  • use_liger_kernel: False
  • liger_kernel_config: None
  • use_cache: False
  • neftune_noise_alpha: None
  • torch_empty_cache_steps: None
  • auto_find_batch_size: False
  • log_on_each_node: True
  • logging_nan_inf_filter: True
  • include_num_input_tokens_seen: no
  • log_level: passive
  • log_level_replica: warning
  • disable_tqdm: False
  • project: huggingface
  • trackio_space_id: None
  • trackio_bucket_id: None
  • trackio_static_space_id: None
  • per_device_eval_batch_size: 64
  • prediction_loss_only: True
  • eval_on_start: False
  • eval_do_concat_batches: True
  • eval_use_gather_object: False
  • eval_accumulation_steps: None
  • include_for_metrics: []
  • batch_eval_metrics: False
  • save_only_model: False
  • save_on_each_node: False
  • enable_jit_checkpoint: False
  • push_to_hub: False
  • hub_private_repo: None
  • hub_model_id: None
  • hub_strategy: every_save
  • hub_always_push: False
  • hub_revision: None
  • load_best_model_at_end: False
  • ignore_data_skip: False
  • restore_callback_states_from_checkpoint: False
  • full_determinism: False
  • seed: 42
  • data_seed: None
  • use_cpu: False
  • accelerator_config: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
  • parallelism_config: None
  • dataloader_drop_last: False
  • dataloader_num_workers: 0
  • dataloader_pin_memory: True
  • dataloader_persistent_workers: False
  • dataloader_prefetch_factor: None
  • remove_unused_columns: True
  • label_names: None
  • train_sampling_strategy: random
  • length_column_name: length
  • ddp_find_unused_parameters: None
  • ddp_bucket_cap_mb: None
  • ddp_broadcast_buffers: False
  • ddp_static_graph: None
  • ddp_backend: None
  • ddp_timeout: 1800
  • fsdp: []
  • fsdp_config: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
  • deepspeed: None
  • debug: []
  • skip_memory_metrics: True
  • do_predict: False
  • resume_from_checkpoint: None
  • warmup_ratio: None
  • local_rank: -1
  • prompts: None
  • batch_sampler: batch_sampler
  • multi_dataset_batch_sampler: proportional
  • router_mapping: {}
  • learning_rate_mapping: {}

Training Logs

Epoch   Step  Training Loss  Validation Loss  stsb-validation_spearman  stsb-test_spearman
-1      -1    -              -                -0.0362                   -
0.2222  20    0.6909         -                -                         -
0.4444  40    0.6506         -                -                         -
0.6667  60    0.5969         -                -                         -
0.8889  80    0.5680         0.5461           0.8552                    -
1.1111  100   0.5551         -                -                         -
1.3333  120   0.5379         -                -                         -
1.5556  140   0.5449         -                -                         -
1.7778  160   0.5443         0.5342           0.8777                    -
2.0     180   0.5373         -                -                         -
2.2222  200   0.5287         -                -                         -
2.4444  220   0.5248         -                -                         -
2.6667  240   0.5283         0.5383           0.8785                    -
2.8889  260   0.5251         -                -                         -
3.1111  280   0.5156         -                -                         -
3.3333  300   0.5093         -                -                         -
3.5556  320   0.5164         0.5369           0.8824                    -
3.7778  340   0.5152         -                -                         -
4.0     360   0.5208         0.5331           0.8838                    -
-1      -1    -              -                -                         0.8404

Training Time

  • Training: 3.2 minutes
  • Evaluation: 15.8 seconds
  • Total: 3.5 minutes

Framework Versions

  • Python: 3.11.14
  • Sentence Transformers: 5.6.0.dev0
  • Transformers: 5.9.0
  • PyTorch: 2.12.0
  • Accelerate: 1.13.0
  • Datasets: 4.8.5
  • Tokenizers: 0.22.2

Citation

BibTeX

Sentence Transformers

@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}

Model size: 82.1M params (Safetensors)
Tensor type: F32