EAVAE: Explainable Author-Variational Autoencoder

A PyTorch Lightning implementation of an Explainable Author-Variational Autoencoder (EAVAE) for learning disentangled style and content representations in text. This model learns to separate an author's writing style from the semantic content, enabling applications in authorship verification, style transfer, and text generation with controlled stylistic attributes.

🎯 Overview

EAVAE is a neural architecture that combines:

  • Style Encoder: Captures author-specific writing patterns (e.g., word choice, sentence structure)
  • Content Encoder: Extracts semantic meaning independent of style
  • Generator: Reconstructs text conditioned on both style and content representations
  • VAE Framework: Uses variational autoencoders for regularized latent space learning

The model achieves disentanglement through adversarial discriminators and mutual information regularization, encouraging the style and content representations to remain independent. The model is published on Hugging Face.

πŸ—οΈ Architecture

Input Text
    ├─> Style Encoder (Bidirectional Qwen) ─> Style VAE ─> Style Latent (z_s)
    └─> Content Encoder (GTE-Qwen) ────────> Content VAE ─> Content Latent (z_c)
                                                    ↓
                                          [z_s ⊕ z_c] → Generator (Qwen)
                                                    ↓
                                            Reconstructed Text

Key Components

  1. Style Encoder (src/model/encoder.py)

    • Bidirectional transformer (Qwen2/Qwen3) for capturing style patterns
    • VAE bottleneck for regularization
    • Configurable with LoRA for efficient fine-tuning
  2. Content Encoder (src/model/encoder.py)

    • Dense retrieval model (e.g., GTE-Qwen2-1.5B)
    • Extracts semantic representations
    • Independent of stylistic variations
  3. Generator (src/model/generator.py)

    • Causal language model (Qwen2.5/Qwen3)
    • Conditioned on concatenated style and content embeddings
    • Optional LoRA adaptation
  4. Discriminators (src/model/model.py)

    • Style discriminator: Encourages content latents to be style-invariant
    • Content discriminator: Encourages style latents to be content-invariant
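
At a glance, a forward pass encodes the input twice, samples a latent from each VAE bottleneck, and feeds the concatenated latents to the generator. The sketch below is illustrative only; the module names and call signatures are assumptions, not the actual API in src/model/.

import torch

def eavae_forward(style_encoder, content_encoder, generator, input_ids, attention_mask):
    # Style branch: bidirectional encoder -> VAE bottleneck -> style latent z_s
    mu_s, logvar_s = style_encoder(input_ids, attention_mask)
    z_s = mu_s + torch.exp(0.5 * logvar_s) * torch.randn_like(mu_s)  # reparameterization

    # Content branch: dense retrieval encoder -> VAE bottleneck -> content latent z_c
    mu_c, logvar_c = content_encoder(input_ids, attention_mask)
    z_c = mu_c + torch.exp(0.5 * logvar_c) * torch.randn_like(mu_c)

    # Generator: causal LM conditioned on the concatenated latents [z_s ⊕ z_c]
    condition = torch.cat([z_s, z_c], dim=-1)
    reconstruction_logits = generator(condition, input_ids, attention_mask)
    return reconstruction_logits, (mu_s, logvar_s), (mu_c, logvar_c)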

📦 Installation

Requirements

  • Python 3.11+
  • CUDA 12.0+ (for GPU training)
  • 80GB+ GPU memory recommended for full model training

Setup

# Clone the repository
git clone https://github.com/yourusername/avae.git
cd avae

# Create conda environment
conda create -n avae python=3.11
conda activate avae

# Install dependencies
pip install -r requirements.txt

Key Dependencies

  • torch>=2.0.0
  • lightning>=2.5.0
  • transformers>=4.36.0
  • flagembedding>=1.3.4
  • peft (for LoRA)
  • wandb (for experiment tracking)

🚀 Quick Start

1. Data

You can find the data for pretraining the Style Encoder at: [Pretrain_data]

For EAVAE training, you can find the data at: [EAVAE_data]

📊 Datasets

Training Datasets

EAVAE is trained on diverse multi-author corpora:

  • Reddit
  • Blog Authorship Corpus
  • Amazon Reviews
  • Goodreads Reviews
  • IMDb Reviews
  • New York Times Comments
  • Yelp Reviews
  • News Articles (RealNews)
  • Wikipedia
  • And more (see Pretrain_data)

Evaluation Benchmarks

  • HRS (HIATUS Reddit Stories): multi-genre authorship verification
  • MUD (Multi-User Detection): Reddit-based authorship attribution
  • PAN20/PAN21: PAN competition authorship verification tasks
  • Amazon Reviews: Product review authorship verification
  • M4: Multi-domain benchmark for AI-generated text detection

Evaluation Results

⚙️ Configuration

All experiments are configured via YAML files in scripts/configs/.

Main Configuration Parameters

Model Architecture

# Encoder models
style_encoder_model_name_or_path: "Qwen/Qwen2-1.5B"
content_encoder_model_name_or_path: "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
generator_model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct"

# Architecture settings
embedding_dim: 1536
pooling_method: mean
dropout_prob: 0.1
use_vae: true

# LoRA settings
use_lora: false
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
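
When use_lora is enabled, the lora_r / lora_alpha / lora_dropout values map directly onto a standard peft LoraConfig. A minimal sketch for the generator is shown below; the target modules and task type are assumptions that depend on the base model, not necessarily what this repo uses.

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
lora_config = LoraConfig(
    r=16,                                 # lora_r
    lora_alpha=32,                        # lora_alpha
    lora_dropout=0.1,                     # lora_dropout
    target_modules=["q_proj", "v_proj"],  # assumption; depends on the base model
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()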

Loss Weights

reconstruction_loss_weight: 1.0      # Reconstruction quality
vae_loss_weight: 1.0e-5              # KL divergence regularization
style_discriminator_loss_weight: 1.0  # Style-invariant content
content_discriminator_loss_weight: 1.0 # Content-invariant style
constraint_loss_weight: 0.1          # Consistency with pretrained encoders
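
These weights combine the individual terms into a single training objective. A minimal sketch of the weighted sum, assuming each component loss has already been computed as a scalar tensor (the variable names are illustrative):

def total_loss(losses, w):
    # losses: dict of scalar loss tensors; w: dict of the weights listed above
    return (
        w["reconstruction_loss_weight"] * losses["reconstruction"]
        + w["vae_loss_weight"] * losses["kl"]
        + w["style_discriminator_loss_weight"] * losses["style_disc"]
        + w["content_discriminator_loss_weight"] * losses["content_disc"]
        + w["constraint_loss_weight"] * losses["constraint"]
    )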

Training Hyperparameters

learning_rate: 5.0e-5
max_epochs: 3
max_steps: 40000
global_batch_size: 32
effective_batch_size: 32  # With gradient accumulation
grad_norm_clip: 1.0
warmpup_proportion: 0.1
weight_decay: 0.0
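
The number of gradient accumulation steps follows from the ratio of the effective and per-step batch sizes; a small sketch under the assumption that effective_batch_size = global_batch_size × accumulation steps (consistent with the troubleshooting note below):

def grad_accumulation_steps(effective_batch_size: int, global_batch_size: int) -> int:
    # Assumption: effective_batch_size = global_batch_size * accumulation steps.
    assert effective_batch_size % global_batch_size == 0
    return effective_batch_size // global_batch_size

grad_accumulation_steps(32, 32)  # -> 1: no accumulation with the defaults above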

🔬 Model Details

Disentanglement Objectives

  1. Reconstruction Loss: Measures how well the generator reconstructs the input

    L_recon = -log P(x | z_s, z_c)
    
  2. VAE KL Loss: Regularizes latent distributions

    L_KL = KL(q(z|x) || p(z))
    
  3. Adversarial Discriminator Losses:

    • Style discriminator tries to predict style from content latents (minimize for content encoder)
    • Content discriminator tries to predict content from style latents (minimize for style encoder)
  4. Constraint Loss: Maintains consistency with pretrained reference encoders

  5. Mutual Information Regularization (optional): Further encourages independence
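
For Gaussian posteriors q(z|x) = N(μ, σ²) with a standard normal prior, the KL term above has a closed form. A self-contained sketch of the reparameterization trick and the per-sample KL (not the repo's exact implementation):

import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps, with eps ~ N(0, I)
    return mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

def gaussian_kl(mu, logvar):
    # KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over the batch
    return (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=-1)).mean()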

Training Strategy

  • FSDP (Fully Sharded Data Parallel): Efficient distributed training
  • Mixed Precision (BF16): Faster training with lower memory
  • Gradient Checkpointing: Trade compute for memory
  • Cyclic KL Annealing: Gradually increases KL weight for stable training
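
The cyclic KL annealing above can be implemented as a simple schedule over the training step. A minimal sketch (linear ramp within each cycle, then flat at 1.0); the number of cycles and ramp ratio are illustrative defaults, not the repo's settings.

def cyclic_kl_weight(step: int, total_steps: int, n_cycles: int = 4, ratio: float = 0.5) -> float:
    # Within each cycle the weight ramps linearly from 0 to 1 over the first
    # `ratio` fraction of the cycle, then stays at 1 for the remainder.
    cycle_len = max(total_steps // n_cycles, 1)
    pos = (step % cycle_len) / cycle_len
    return min(pos / ratio, 1.0)

# Example: with max_steps = 40000 and 4 cycles, the KL weight resets every 10000 steps.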

📝 License

This project is licensed under the MIT License - see the LICENSE file for details.

🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

Development Setup

# Install development dependencies
pip install -r requirements.txt

# Run tests (if available)
pytest tests/

πŸ› Troubleshooting

Common Issues

  1. OOM (Out of Memory)

    • Reduce global_batch_size
    • Enable use_cpu_offload: true
    • Use gradient accumulation (effective_batch_size > global_batch_size)
    • Enable activation checkpointing
  2. FSDP Errors

    • Ensure all model components are properly wrapped
    • Check that nodes and devices match your hardware
    • Try switching to DDP strategy for debugging
  3. NaN Loss

    • Reduce learning rate
    • Increase warmup steps
    • Check loss weight balancing
    • Enable gradient clipping

📧 Contact

For questions or issues, please open an issue on GitHub or contact [[email protected]].

🙏 Acknowledgments


Note: This is research code. For production use, additional testing and optimization may be required.
