
LLaVA-Llama-3-8b with Test-Time Register

Register tokens in ViTs were introduced as learnable tokens in "Vision Transformers Need Registers" to mitigate artifacts in intermediate feature maps. In "Vision Transformers Don't Need Trained Registers", we introduced a training-free method to create registers. These test-time registers serve the same purpose as the original trained registers but can be added post hoc to any ViT to mitigate artifacts, enhance model interpretability, and modestly improve downstream performance on tasks such as segmentation and depth estimation.

Model description

The base model is LLaVA-Llama-3-8b v1.1. With test-time registers, the model's internal representations are cleaner, which makes them more useful for debugging model behavior. Below, we visualize the language model's attention from the generated response to the visual tokens (zoom in for detail). We run evaluation using VLMEvalKit with the environment from here (transformers==4.37.0). This model is intended to be used with this repo. It can also be used for fine-tuning or other downstream tasks.

[Figure: language-model attention from the generated response to the visual tokens]
| Model | Avg. | HallusionBench | MMVet | MMMU Val | OCRBench | MMStar | MathVista | AI2D Test | MMBench v1.1 |
|---|---|---|---|---|---|---|---|---|---|
| LLaVA-Llama-3-8B v1.1 | 46.2 | 28.6 | 33.4 | 40.4 | 41.6 | 46.3 | 40.9 | 69.9 | 68.5 |
| w/ test-time register | 46.2 | 29.4 | 33.9 | 40.1 | 41.3 | 46.4 | 41.3 | 69.4 | 68.0 |
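The Avg. column appears to be the unweighted mean of the eight benchmark scores. The short check below recomputes it from the table and prints the per-benchmark change from adding test-time registers. The scores are copied from the table; treating Avg. as an unweighted mean rounded to one decimal is an assumption that the numbers are consistent with.

```python
# Scores copied from the table above, in the same column order.
BENCHMARKS = ["HallusionBench", "MMVet", "MMMU Val", "OCRBench",
              "MMStar", "MathVista", "AI2D Test", "MMBench v1.1"]
baseline = [28.6, 33.4, 40.4, 41.6, 46.3, 40.9, 69.9, 68.5]
with_registers = [29.4, 33.9, 40.1, 41.3, 46.4, 41.3, 69.4, 68.0]


def average(scores):
    """Unweighted mean, rounded to one decimal like the Avg. column (assumed)."""
    return round(sum(scores) / len(scores), 1)


if __name__ == "__main__":
    print("Avg. (baseline):", average(baseline))
    print("Avg. (w/ test-time register):", average(with_registers))
    # Per-benchmark change from adding test-time registers
    for name, before, after in zip(BENCHMARKS, baseline, with_registers):
        print(f"{name:15s} {after - before:+.1f}")
```

Both rows average to 46.2: small gains (e.g., HallusionBench +0.8) and small losses (e.g., AI2D Test -0.5) roughly cancel out.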

Quick Start

import sys

import torch
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoProcessor

# Download the repo and make its custom modeling code importable
repo_path = snapshot_download("amildravid4292/llava-llama-3-8b-test-time-registers")
sys.path.insert(0, repo_path)
from modeling_custom_llava import LlavaRegistersForConditionalGeneration

device = "cuda:0"

model = LlavaRegistersForConditionalGeneration.from_pretrained(
    "xtuner/llava-llama-3-8b-v1_1-transformers",
    torch_dtype=torch.float16,
    output_attentions=True,
).to(device)

# Use the original processor
processor = AutoProcessor.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers")

prompt = ("<|start_header_id|>user<|end_header_id|>\n\n<image>\nHow many tennis balls are in the dog's mouth? Use one word.<|eot_id|>"
          "<|start_header_id|>assistant<|end_header_id|>\n\n")

# Load the image (replace with the path to your own image)
image_path = "dog_image.webp"
raw_image = Image.open(image_path).convert("RGB")

inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(device, torch.float16)

# The model uses test-time registers by default
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=20, do_sample=False)

# To run without test-time registers
with torch.no_grad():
    output_no_reg = model.generate(**inputs, max_new_tokens=20, do_sample=False, extra_tokens=0, neuron_dict=None)

tokenizer = processor.tokenizer
print("With test-time registers:", tokenizer.decode(output[0], skip_special_tokens=True))
print("Without test-time registers:", tokenizer.decode(output_no_reg[0], skip_special_tokens=True))
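Because `output[0]` holds the prompt tokens followed by the answer, the decoded string above also contains the prompt text. The helper below is a small sketch (the name `decode_answer` is hypothetical, not part of the repo) that keeps only the tokens generated after the prompt. It assumes, as is usual for `generate` with decoder-only models, that the returned sequence begins with the prompt's `input_ids`.

```python
def decode_answer(tokenizer, output_ids, prompt_len):
    """Decode only the tokens generated after the prompt.

    output_ids: one generated sequence, e.g. output[0] from model.generate.
    prompt_len: number of prompt tokens, e.g. inputs["input_ids"].shape[1].
    """
    new_tokens = output_ids[prompt_len:]  # drop the echoed prompt tokens
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


# Usage with the Quick Start variables (hypothetical helper, see above):
# prompt_len = inputs["input_ids"].shape[1]
# print(decode_answer(processor.tokenizer, output[0], prompt_len))
```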

Visualizing the Language Model's Attention to Visual Tokens

import sys

import matplotlib.pyplot as plt
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoProcessor

repo_path = snapshot_download("amildravid4292/llava-llama-3-8b-test-time-registers")
sys.path.insert(0, repo_path)
from modeling_custom_llava import LlavaRegistersForConditionalGeneration

device = "cuda:0"

# Subclass that stores the language model's attention weights from the latest forward pass
class AttentionCaptureModel(LlavaRegistersForConditionalGeneration):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.captured_attentions = None

    def forward(self, *args, **kwargs):
        output = super().forward(*args, **kwargs)
        # Tuple with one (batch, heads, seq_len, seq_len) tensor per language-model layer
        self.captured_attentions = output.attentions
        return output


model = AttentionCaptureModel.from_pretrained(
    "xtuner/llava-llama-3-8b-v1_1-transformers",
    torch_dtype=torch.float16,
    output_attentions=True,  # needed so the forward pass returns attention weights
).to(device)

# Use the original processor
processor = AutoProcessor.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers")

prompt = ("<|start_header_id|>user<|end_header_id|>\n\n<image>\nHow many tennis balls are in the dog's mouth? Use one word.<|eot_id|>"
          "<|start_header_id|>assistant<|end_header_id|>\n\n")

# Load the image (replace with the path to your own image)
image_path = "dog_image.webp"
raw_image = Image.open(image_path).convert("RGB")

inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(device, torch.float16)

# The model uses test-time registers by default. With max_new_tokens=1, generate runs a
# single forward pass over the prompt, so the captured attentions come from that pass.
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=1, do_sample=False)

tokenizer = processor.tokenizer
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
print("Decoded output:", decoded_output)

# Concatenate the per-layer attentions: (num_layers * batch, heads, seq_len, seq_len)
atts = torch.cat(model.captured_attentions).float()
# Average over layers and heads, take the last query position (the one that predicts the
# answer), and keep the 576 visual-token keys at positions 5:581 as a 24x24 patch grid
im = plt.imshow(atts.mean(0).mean(0)[-1, 5:581].cpu().reshape(24, 24))
plt.axis("off")
plt.suptitle("Mean Attention Map for Answer Token", fontsize=20)
plt.tight_layout()
plt.colorbar(im)
plt.show()
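The slice `5:581` and the 24x24 reshape above are hard-coded for this particular prompt. The sketch below derives them instead, under two assumptions: (1) as in the transformers 4.37 LLaVA implementation, the single `<image>` placeholder in `input_ids` is expanded in place into the vision tower's patch embeddings, so the visual tokens start at the placeholder's position; and (2) there are 576 of them, a 24x24 grid (a 336-pixel input split into 14-pixel patches). `visual_token_slice` and `attention_heatmap` are hypothetical helpers, not part of the repo; the second one min-max normalizes the map so it can be overlaid on the image.

```python
import numpy as np


def visual_token_slice(input_ids, image_token_index, num_patches=576):
    """Slice of attention positions occupied by visual tokens.

    Assumes exactly one image placeholder in input_ids, which the model
    replaces in place with num_patches patch embeddings.
    """
    ids = list(input_ids)
    start = ids.index(image_token_index)  # position of the <image> placeholder
    return slice(start, start + num_patches)


def attention_heatmap(attn_row, grid=24):
    """Reshape a 1D attention row over visual tokens into a [0, 1] grid."""
    attn = np.asarray(attn_row, dtype=np.float32).reshape(grid, grid)
    attn = attn - attn.min()
    peak = attn.max()
    return attn / peak if peak > 0 else attn


# Usage with the variables from the visualization code above, assuming the
# config exposes image_token_index as transformers' LlavaConfig does:
# vis = visual_token_slice(inputs["input_ids"][0].tolist(), model.config.image_token_index)
# heat = attention_heatmap(atts.mean(0).mean(0)[-1, vis].cpu().numpy())
# plt.imshow(raw_image)
# plt.imshow(heat, alpha=0.5, cmap="jet",
#            extent=(0, raw_image.width, raw_image.height, 0))
```

The processor may resize or crop the image before the vision tower sees it, so an overlay on the raw image is only approximately aligned, especially for non-square images.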

Advanced Usage

Custom Neuron Modifications

# Override the saved neuron configuration (reuses model and inputs from Quick Start)
custom_neuron_dict = {0: [10, 20, 30]}  # Modify neurons 10, 20, and 30 in layer 0
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=20, do_sample=False, neuron_dict=custom_neuron_dict)
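`neuron_dict` maps a layer index to a list of neuron indices. The toy NumPy sketch below illustrates one way such a list could be applied to a single layer's MLP activations: a zero-initialized register token is appended, and for each listed neuron the largest activation across the existing tokens is moved onto that register while the original tokens' values for that neuron are zeroed. This is an assumption-laden illustration of the general idea, not the repo's implementation; the function name `shift_to_register` and the `(tokens, neurons)` array layout are hypothetical.

```python
import numpy as np


def shift_to_register(acts, neuron_ids):
    """Toy illustration (hypothetical, not the repo's code).

    acts: array of shape (num_tokens, num_neurons) with one layer's MLP
    activations. Returns an array with one extra (register) token row.
    """
    acts = np.asarray(acts, dtype=np.float32)
    register = np.zeros((1, acts.shape[1]), dtype=acts.dtype)
    out = np.concatenate([acts, register], axis=0)  # (num_tokens + 1, num_neurons)
    for n in neuron_ids:
        out[-1, n] = acts[:, n].max()  # register token takes the peak activation
        out[:-1, n] = 0.0              # original tokens no longer carry it
    return out


# Usage on toy activations, with the layer-0 entry of a neuron_dict-style mapping:
# custom_neuron_dict = {0: [10, 20, 30]}
# layer0 = np.random.randn(577, 64).astype(np.float32)
# shifted = shift_to_register(layer0, custom_neuron_dict[0])  # shape (578, 64)
```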

Different Register Token Counts

# Use a different number of register tokens (reuses model and inputs from Quick Start)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=20, do_sample=False, extra_tokens=5)
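To see how the number of register tokens affects the answer, one can loop over several `extra_tokens` values. The sketch below is hypothetical (the name `sweep_register_counts` is not part of the repo). It reuses `model`, `processor`, and `inputs` from Quick Start and, following Quick Start, treats a count of 0 as turning test-time registers off by also passing `neuron_dict=None`; for other counts the saved neuron configuration is left in place. It assumes the generated sequence starts with the prompt's `input_ids`, so slicing them off leaves only the answer.

```python
def sweep_register_counts(model, tokenizer, inputs, counts, max_new_tokens=20):
    """Return {register count: decoded answer} for each count in counts."""
    prompt_len = inputs["input_ids"].shape[1]
    answers = {}
    for n in counts:
        kwargs = {"extra_tokens": n}
        if n == 0:
            kwargs["neuron_dict"] = None  # fully disable registers, as in Quick Start
        output = model.generate(**inputs, max_new_tokens=max_new_tokens,
                                do_sample=False, **kwargs)
        # Keep only the tokens generated after the prompt
        answers[n] = tokenizer.decode(output[0][prompt_len:],
                                      skip_special_tokens=True).strip()
    return answers


# Usage (call under torch.no_grad(), as in the examples above):
# with torch.no_grad():
#     results = sweep_register_counts(model, processor.tokenizer, inputs, [0, 1, 5])
# for n, answer in results.items():
#     print(n, answer)
```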

BibTeX entry and citation info

@misc{jiang2025visiontransformersdontneed,
      title={Vision Transformers Don't Need Trained Registers}, 
      author={Nick Jiang and Amil Dravid and Alexei Efros and Yossi Gandelsman},
      year={2025},
      eprint={2506.08010},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2506.08010}, 
}