Text Generation
Transformers
PyTorch
English
retnet
custom_code

Hybrid RetNet

This is a RetNet model accompanying the paper Cross-Architecture Transfer Learning for Linear-Cost Inference Transformers. In this work, we propose not to train new Linear-Cost Inference models (e.g. RetNet) from scratch, but to transfer shared weight components from other pretrained language models (PTLMs). The model's input/output embeddings, MLP weights, Layer Norms, and Attention Output Projections ($W_O$) have been transferred from pythia-410m. For more detail, please refer to the paper.
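The transfer idea can be illustrated with a minimal sketch. It is not the paper's implementation: the parameter names, the `SHARED_FRAGMENTS` tuple, and the helper functions `split_transferable` and `transfer` are all hypothetical, and plain strings stand in for weight tensors so the example runs without any dependencies. It only shows the general pattern of copying shared components from a source checkpoint while leaving architecture-specific parameters at their initial values.

```python
from typing import Dict, Iterable, List, Tuple

# Hypothetical name fragments marking components that both architectures share.
# Real checkpoints use their own parameter names; adjust these before use.
SHARED_FRAGMENTS = ("embed", "mlp", "layer_norm", "out_proj")


def split_transferable(
    source_state: Dict[str, object],
    shared_fragments: Iterable[str] = SHARED_FRAGMENTS,
) -> Tuple[Dict[str, object], List[str]]:
    """Split source entries into those to transfer and those to skip."""
    fragments = tuple(shared_fragments)
    transferred: Dict[str, object] = {}
    skipped: List[str] = []
    for name, value in source_state.items():
        if any(fragment in name for fragment in fragments):
            transferred[name] = value
        else:
            skipped.append(name)
    return transferred, skipped


def transfer(
    source_state: Dict[str, object],
    target_state: Dict[str, object],
    shared_fragments: Iterable[str] = SHARED_FRAGMENTS,
) -> Dict[str, object]:
    """Return a copy of target_state with shared entries taken from source_state.

    Entries that exist only in the target (for example, retention parameters)
    keep their initial values.
    """
    transferred, _ = split_transferable(source_state, shared_fragments)
    merged = dict(target_state)
    for name, value in transferred.items():
        if name in merged:
            merged[name] = value
    return merged


# Toy state dicts; strings stand in for weight tensors (illustrative names only).
source = {
    "embed_in.weight": "src_embed",
    "layers.0.layer_norm.weight": "src_ln",
    "layers.0.mlp.up.weight": "src_mlp",
    "layers.0.attn.qkv.weight": "src_qkv",
    "layers.0.attn.out_proj.weight": "src_wo",
}
target = {
    "embed_in.weight": "init",
    "layers.0.layer_norm.weight": "init",
    "layers.0.mlp.up.weight": "init",
    "layers.0.retention.qkv.weight": "init",
    "layers.0.attn.out_proj.weight": "init",
}

merged = transfer(source, target)
for name, value in merged.items():
    print(f"{name}: {value}")
```

With real checkpoints the same pattern would apply to `state_dict()` mappings, provided the names are first mapped between the two architectures, which this sketch does not attempt.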

Model Details

Model Description

  • Developed by: NucleusAI, Sehyun Choi
  • Model type: RetNet & Transformer Hybrid

Model Sources

How to Get Started with the Model

Use the code below to get started with the model.

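import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Create all new tensors (including the tokenized inputs) on the GPU.
torch.set_default_device("cuda")

# trust_remote_code=True is required because the repository ships custom model code.
model = AutoModelForCausalLM.from_pretrained("NucleusAI/RetNet-410m-XATL", torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("NucleusAI/RetNet-410m-XATL", trust_remote_code=True)  # same as EleutherAI/pythia-1B

# Tokenize the prompt; the attention mask is not passed to generate().
inputs = tokenizer("Hi there!", return_tensors="pt", return_attention_mask=False)

# Generate up to 200 tokens in total (prompt included) and decode the first sequence.
outputs = model.generate(**inputs, max_length=200)
text = tokenizer.batch_decode(outputs)[0]
print(text)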

Training Data

The model was trained on the pile_dedup dataset to allow direct comparison with Pythia models of the same size.
