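# Hugging Face Hub page residue (commit metadata), kept as comments so the module parses:
# author: btrkeks
# commit b529f8a (verified): "Add transformers-compatible artifacts + model card"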
"""
Encoder loading factory for Transcoda model.
This module provides a flexible factory pattern for loading different vision encoder
backbones (ConvNeXt, ResNet, Swin, etc.) through a unified interface.
"""
import torch.nn as nn
from loguru import logger
from transformers import AutoModel
from .configuration_transcoda import TranscodaConfig
class EncoderLoader:
    """Factory class for loading vision encoders with automatic configuration detection.

    All methods are static; the class only serves as a namespace. ``load`` is the
    public entry point, the underscore-prefixed methods are implementation details.
    """

    @staticmethod
    def load(config: TranscodaConfig) -> tuple[nn.Module, int]:
        """
        Load an encoder based on the specified provider.

        The provider name is matched case-insensitively. After loading, the
        embeddings and the first ``config.freeze_encoder_stages`` stages are
        frozen when that value is positive.

        Args:
            config: Transcoda model configuration. Reads ``encoder_provider``,
                ``encoder_model_name_or_path`` and ``freeze_encoder_stages``.

        Returns:
            tuple: (encoder module, encoder output dimension)

        Raises:
            ValueError: If the encoder provider is not supported
        """
        provider = config.encoder_provider.lower()
        if provider == "transformers":
            encoder, output_dim = EncoderLoader._load_transformers(config)
        else:
            raise ValueError(
                f"Unsupported encoder provider: {config.encoder_provider}. "
                "Supported providers: ['transformers']"
            )

        # Apply freezing if configured
        if config.freeze_encoder_stages > 0:
            EncoderLoader._freeze_encoder_stages(encoder, config.freeze_encoder_stages)

        logger.info(
            f"Loaded encoder '{config.encoder_model_name_or_path}' "
            f"via {provider} with output dimension {output_dim}"
        )
        return encoder, output_dim

    @staticmethod
    def _load_transformers(config: TranscodaConfig) -> tuple[nn.Module, int]:
        """
        Load encoder using HuggingFace Transformers AutoModel.

        Args:
            config: Transcoda model configuration; ``encoder_model_name_or_path``
                is passed to ``AutoModel.from_pretrained``.

        Returns:
            tuple: (encoder module, encoder output dimension)
        """
        # AutoModel resolves the concrete architecture from the checkpoint's config.
        encoder = AutoModel.from_pretrained(config.encoder_model_name_or_path)
        # Detect output dimension from encoder config
        output_dim = EncoderLoader._detect_encoder_output_dim(encoder.config)
        logger.info(
            f"Loaded transformers encoder: {config.encoder_model_name_or_path}, "
            f"output_dim={output_dim}"
        )
        return encoder, output_dim

    @staticmethod
    def _detect_encoder_output_dim(encoder_config) -> int:
        """
        Heuristic to detect the encoder's output dimension.

        Different encoder architectures expose their output dimension via
        different config attributes. The checks run in this order, and the
        first attribute that is present and not ``None``/empty wins:

        1. ``hidden_sizes[-1]`` (per-stage widths, e.g. ConvNeXt, ResNet)
        2. ``vision_config.hidden_size`` (composite configs with a nested vision config)
        3. ``hidden_size`` (e.g. BERT, ViT)
        4. ``embed_dim``
        5. ``num_channels``: last resort only, with a warning, because in HF
           vision configs it usually denotes the number of *input* image channels.

        Args:
            encoder_config: HuggingFace model config object

        Returns:
            int: Detected output dimension

        Raises:
            ValueError: If output dimension cannot be detected
        """
        # Per-stage width lists (ConvNeXt, ResNet, some CNNs): the last entry is
        # the width of the final stage.
        hidden_sizes = getattr(encoder_config, "hidden_sizes", None)
        if hidden_sizes:
            return hidden_sizes[-1]

        # Composite configs (e.g. SAM) nest the vision tower's config.
        # NOTE(review): SAM's vision encoder projects through a neck to
        # ``vision_config.output_channels``; confirm hidden_size is the intended dim.
        nested_hidden_size = getattr(
            getattr(encoder_config, "vision_config", None), "hidden_size", None
        )
        if nested_hidden_size is not None:
            return nested_hidden_size

        # Flat configs (BERT, ViT, many transformers) expose hidden_size; a few
        # only expose embed_dim. Both must be checked before num_channels.
        for attr in ("hidden_size", "embed_dim"):
            value = getattr(encoder_config, attr, None)
            if value is not None:
                return value

        # num_channels is normally the *input* channel count (e.g. 3 for RGB), so
        # it is only used as a last resort and flagged loudly.
        num_channels = getattr(encoder_config, "num_channels", None)
        if num_channels is not None:
            logger.warning(
                f"Falling back to config.num_channels={num_channels} as the encoder "
                "output dimension; this attribute usually denotes input image "
                "channels, so verify the result."
            )
            return num_channels

        # If we can't detect it, raise an error listing only public attributes
        # (dir() on its own is dominated by dunder/private names).
        public_attrs = [name for name in dir(encoder_config) if not name.startswith("_")]
        raise ValueError(
            "Could not detect encoder output dimension from config. "
            f"Available config attributes: {public_attrs}"
        )

    @staticmethod
    def _freeze_module(module: nn.Module) -> None:
        """Set ``requires_grad = False`` on every parameter yielded by ``module.parameters()``."""
        for param in module.parameters():
            param.requires_grad = False

    @staticmethod
    def _freeze_encoder_stages(encoder: nn.Module, num_stages: int) -> None:
        """
        Freeze the embedding layer and the first ``num_stages`` stages of the encoder.

        Currently implements ConvNeXt(V2)-style freezing. It expects
        ``encoder.embeddings`` and ``encoder.encoder.stages`` (an indexable,
        sized container of submodules). Missing pieces are skipped with a
        warning rather than raising.

        Args:
            encoder: The encoder module to freeze
            num_stages: Number of leading stages to freeze. This is a count, so
                stages ``0 .. num_stages - 1`` are frozen. A value larger than the
                number of available stages freezes all of them and logs one warning.
        """
        # Freeze embeddings
        embeddings = getattr(encoder, "embeddings", None)
        if embeddings is not None:
            EncoderLoader._freeze_module(embeddings)
            logger.info("Froze encoder embeddings.")
        else:
            logger.warning(
                "Encoder does not have 'embeddings' attribute. Skipping embeddings freezing."
            )

        # Freeze early stages (ConvNeXt layout: encoder.encoder.stages[i])
        stages = getattr(getattr(encoder, "encoder", None), "stages", None)
        if stages is None:
            logger.warning(
                "Encoder does not have 'encoder.stages' structure. "
                "Cannot freeze stages. This is expected for non-ConvNeXt architectures."
            )
            return

        available = len(stages)
        for i in range(min(num_stages, available)):
            EncoderLoader._freeze_module(stages[i])
            logger.info(f"Froze encoder stage {i}.")
        if num_stages > available:
            # One summary warning instead of one per missing stage.
            logger.warning(
                f"Requested freezing {num_stages} stages, but encoder only has "
                f"{available} stages. Froze all available stages."
            )