# EAVAE: Explainable Author-Variational Autoencoder
A PyTorch Lightning implementation of an Explainable Author-Variational Autoencoder (EAVAE) for learning disentangled style and content representations in text. This model learns to separate an author's writing style from the semantic content, enabling applications in authorship verification, style transfer, and text generation with controlled stylistic attributes.
## 🎯 Overview
EAVAE is a neural architecture that combines:
- **Style Encoder**: Captures author-specific writing patterns (e.g., word choice, sentence structure)
- **Content Encoder**: Extracts semantic meaning independent of style
- **Generator**: Reconstructs text conditioned on both style and content representations
- **VAE Framework**: Uses variational autoencoders for regularized latent space learning
The model achieves disentanglement through adversarial discriminators and mutual information regularization, ensuring that style and content representations remain independent. The pretrained model is published on Hugging Face.
## 🏗️ Architecture
```
Input Text
  ├─> Style Encoder (Bidirectional Qwen) ─> Style VAE ───> Style Latent (z_s)
  └─> Content Encoder (GTE-Qwen) ─────────> Content VAE ─> Content Latent (z_c)
                                  ↓
                    [z_s ⊕ z_c] ──> Generator (Qwen)
                                  ↓
                         Reconstructed Text
```
### Key Components
#### Style Encoder (`src/model/encoder.py`)
- Bidirectional transformer (Qwen2/Qwen3) for capturing style patterns
- VAE bottleneck for regularization
- Configurable with LoRA for efficient fine-tuning

#### Content Encoder (`src/model/encoder.py`)
- Dense retrieval model (e.g., GTE-Qwen2-1.5B)
- Extracts semantic representations
- Independent of stylistic variations

#### Generator (`src/model/generator.py`)
- Causal language model (Qwen2.5/Qwen3)
- Conditioned on concatenated style and content embeddings
- Optional LoRA adaptation

#### Discriminators (`src/model/model.py`)
- Style discriminator: encourages content latents to be style-invariant
- Content discriminator: encourages style latents to be content-invariant
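To make the data flow concrete, here is a minimal sketch of the forward pass implied by the diagram above. All class, method, and argument names in this snippet are illustrative assumptions, not the repo's actual API (see `src/model/` for the real implementation):

```python
import torch
import torch.nn as nn

class EAVAESketch(nn.Module):
    """Illustrative sketch of the EAVAE forward pass (hypothetical API)."""

    def __init__(self, style_encoder, content_encoder, generator, embedding_dim=1536):
        super().__init__()
        self.style_encoder = style_encoder      # bidirectional Qwen backbone
        self.content_encoder = content_encoder  # GTE-Qwen retrieval model
        self.generator = generator              # causal LM (Qwen2.5/Qwen3)
        # VAE heads: project pooled encoder outputs to mean and log-variance
        self.style_mu = nn.Linear(embedding_dim, embedding_dim)
        self.style_logvar = nn.Linear(embedding_dim, embedding_dim)
        self.content_mu = nn.Linear(embedding_dim, embedding_dim)
        self.content_logvar = nn.Linear(embedding_dim, embedding_dim)

    @staticmethod
    def reparameterize(mu, logvar):
        # Standard VAE reparameterization trick: z = mu + sigma * eps
        return mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

    def forward(self, input_ids, attention_mask):
        h_s = self.style_encoder(input_ids, attention_mask)    # pooled style features
        h_c = self.content_encoder(input_ids, attention_mask)  # pooled content features
        z_s = self.reparameterize(self.style_mu(h_s), self.style_logvar(h_s))
        z_c = self.reparameterize(self.content_mu(h_c), self.content_logvar(h_c))
        cond = torch.cat([z_s, z_c], dim=-1)  # [z_s ⊕ z_c]
        return self.generator(input_ids, attention_mask, condition=cond)
```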
## 📦 Installation
### Requirements
- Python 3.11+
- CUDA 12.0+ (for GPU training)
- 80GB+ GPU memory recommended for full model training
### Setup
```bash
# Clone the repository
git clone https://github.com/yourusername/avae.git
cd avae

# Create conda environment
conda create -n avae python=3.11
conda activate avae

# Install dependencies
pip install -r requirements.txt
```
### Key Dependencies
- `torch>=2.0.0`
- `lightning>=2.5.0`
- `transformers>=4.36.0`
- `flagembedding>=1.3.4`
- `peft` (for LoRA)
- `wandb` (for experiment tracking)
## 🚀 Quick Start
### 1. Data
The pretraining data for the Style Encoder is available at: [Pretrain_data]

The EAVAE training data is available at: [EAVAE_data]
## 📚 Datasets
### Training Datasets
EAVAE is trained on diverse multi-author corpora:
- Blog Authorship Corpus
- Amazon Reviews
- Goodreads Reviews
- IMDb Reviews
- New York Times Comments
- Yelp Reviews
- News Articles (RealNews)
- Wikipedia
- And more (see Pretrain_data)
### Evaluation Benchmarks
- HRS (HIATUS Reddit Stories): multi-genre authorship verification
- MUD (Multi-User Detection): Reddit-based authorship attribution
- PAN20/PAN21: PAN competition authorship verification tasks
- Amazon Reviews: Product review authorship verification
- M4: Multi-domain benchmark for AI-generated text detection
## ⚙️ Configuration
All experiments are configured via YAML files in `scripts/configs/`.
### Main Configuration Parameters
#### Model Architecture
```yaml
# Encoder models
style_encoder_model_name_or_path: "Qwen/Qwen2-1.5B"
content_encoder_model_name_or_path: "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
generator_model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct"

# Architecture settings
embedding_dim: 1536
pooling_method: mean
dropout_prob: 0.1
use_vae: true

# LoRA settings
use_lora: false
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
```
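For illustration only, the LoRA settings above map onto `peft` roughly as follows; the repo's actual wiring lives in the model classes and may pass additional arguments such as `target_modules`:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
lora_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.1)  # mirrors the YAML above
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()  # only the low-rank adapters remain trainable
```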
#### Loss Weights
```yaml
reconstruction_loss_weight: 1.0         # Reconstruction quality
vae_loss_weight: 1.0e-5                 # KL divergence regularization
style_discriminator_loss_weight: 1.0    # Style-invariant content
content_discriminator_loss_weight: 1.0  # Content-invariant style
constraint_loss_weight: 0.1             # Consistency with pretrained encoders
```
#### Training Hyperparameters
```yaml
learning_rate: 5.0e-5
max_epochs: 3
max_steps: 40000
global_batch_size: 32
effective_batch_size: 32  # With gradient accumulation
grad_norm_clip: 1.0
warmpup_proportion: 0.1
weight_decay: 0.0
```
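Assuming `effective_batch_size = global_batch_size × accumulation steps`, the number of gradient-accumulation steps can be derived as below (an illustrative helper, not a function from this repo):

```python
def accumulation_steps(global_batch_size: int, effective_batch_size: int) -> int:
    """Gradient-accumulation steps so each optimizer update sees
    effective_batch_size samples; assumes an exact multiple."""
    assert effective_batch_size % global_batch_size == 0
    return effective_batch_size // global_batch_size

accumulation_steps(32, 32)  # -> 1: with the values above, no accumulation is needed
```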
## 🔬 Model Details
### Disentanglement Objectives
- **Reconstruction Loss**: Measures how well the generator reconstructs the input:
  `L_recon = -log P(x | z_s, z_c)`
- **VAE KL Loss**: Regularizes the latent distributions:
  `L_KL = KL(q(z|x) || p(z))`
- **Adversarial Discriminator Losses**:
  - The style discriminator tries to predict style from content latents (minimized by the content encoder)
  - The content discriminator tries to predict content from style latents (minimized by the style encoder)
- **Constraint Loss**: Maintains consistency with the pretrained reference encoders
- **Mutual Information Regularization** (optional): Further encourages independence between the style and content latents
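A hedged sketch of how these objectives might be combined into a single training loss, using the weight names from the Configuration section (the repo's actual weighting in `src/model/model.py` may differ; argument names here are assumptions):

```python
def total_loss(recon, kl_style, kl_content, style_disc, content_disc, constraint, cfg):
    """Weighted sum of the disentanglement objectives listed above (illustrative)."""
    return (
        cfg["reconstruction_loss_weight"] * recon               # L_recon
        + cfg["vae_loss_weight"] * (kl_style + kl_content)      # L_KL
        + cfg["style_discriminator_loss_weight"] * style_disc
        + cfg["content_discriminator_loss_weight"] * content_disc
        + cfg["constraint_loss_weight"] * constraint
    )
```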
### Training Strategy
- **FSDP (Fully Sharded Data Parallel)**: Efficient distributed training
- **Mixed Precision (BF16)**: Faster training with lower memory
- **Gradient Checkpointing**: Trades compute for memory
- **Cyclic KL Annealing**: Gradually increases the KL weight for stable training (see the sketch below)
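Cyclic KL annealing is typically implemented as a repeating linear ramp; here is a minimal sketch (the cycle length and ramp fraction are placeholder values, not this repo's actual schedule):

```python
def cyclic_kl_weight(step: int, cycle_len: int = 10_000,
                     ramp_frac: float = 0.5, max_weight: float = 1.0) -> float:
    """Ramp the KL weight linearly from 0 to max_weight over the first
    ramp_frac of each cycle, then hold it until the cycle restarts."""
    pos = (step % cycle_len) / cycle_len
    return max_weight * min(pos / ramp_frac, 1.0)

# Steps 0, 2500, 5000, 7500 of a 10k cycle -> 0.0, 0.5, 1.0, 1.0
```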
## 📄 License
This project is licensed under the MIT License - see the LICENSE file for details.
## 🤝 Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
### Development Setup
```bash
# Install development dependencies
pip install -r requirements.txt

# Run tests (if available)
pytest tests/
```
## 🐛 Troubleshooting
### Common Issues
**OOM (Out of Memory)**
- Reduce `global_batch_size`
- Enable `use_cpu_offload: true`
- Use gradient accumulation (`effective_batch_size` > `global_batch_size`)
- Enable activation checkpointing
**FSDP Errors**
- Ensure all model components are properly wrapped
- Check that `nodes` and `devices` match your hardware
- Try switching to the DDP strategy for debugging (see the example below)
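When debugging, swapping strategies on the Lightning `Trainer` is usually a one-line change (a generic example, not this repo's launch script):

```python
import lightning as L

# Swap FSDP for plain DDP while debugging; devices/num_nodes must match hardware.
trainer = L.Trainer(strategy="ddp", devices=4, num_nodes=1)
```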
**NaN Loss**
- Reduce the learning rate
- Increase warmup steps
- Check loss weight balancing
- Enable gradient clipping (see the example below)
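In PyTorch Lightning, gradient clipping and anomaly detection can be enabled directly on the `Trainer` (generic settings, not specific to this repo):

```python
import lightning as L

trainer = L.Trainer(
    gradient_clip_val=1.0,  # matches grad_norm_clip in the config above
    detect_anomaly=True,    # locate the op producing NaNs/Infs (slow; debug only)
    precision="bf16-mixed",
)
```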
## 📧 Contact
For questions or issues, please open an issue on GitHub or contact [[email protected]].
## 🙏 Acknowledgments
- Built with PyTorch Lightning
- Uses models from Hugging Face Transformers
- Inspired by disentanglement research in VAEs and style transfer
**Note**: This is research code. For production use, additional testing and optimization may be required.