- https://github.com/tesla-cat/nanoDeepSeek/blob/main/compare_nanoGPT.py
- https://github.com/tesla-cat/nanoDeepSeek/tree/main/src/nanoGPT
model
import math
from dataclasses import dataclass
from typing import Dict
import torch as tc
import torch.nn as nn
from torch.nn import functional as F
TS_DICT = Dict[str, tc.Tensor]
@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embed: int = 768
    dropout: float = 0.0
    bias: bool = True
class Attention(nn.Module):
    def __init__(s, c: GPTConfig):
        super().__init__()
        s.conf = c
        assert c.n_embed % c.n_head == 0
        E = c.n_embed
        s.c_attn = nn.Linear(E, 3 * E, c.bias)  # fused q, k, v projection
        s.c_proj = nn.Linear(E, E, c.bias)
        s.resid_dropout = nn.Dropout(c.dropout)

    def forward(s, x: tc.Tensor):
        B, T, E = x.shape
        H = s.conf.n_head
        y1: tc.Tensor = s.c_attn(x)
        # split into q, k, v and reshape to (B, H, T, E // H)
        q, k, v = [z.view(B, T, H, E // H).transpose(1, 2) for z in y1.split(E, dim=2)]
        drop = s.conf.dropout if s.training else 0
        y2 = F.scaled_dot_product_attention(q, k, v, dropout_p=drop, is_causal=True)
        y3 = y2.transpose(1, 2).contiguous().view(B, T, E)  # back to (B, T, E)
        return s.resid_dropout(s.c_proj(y3))
class MLP(nn.Module):
    def __init__(s, c: GPTConfig):
        super().__init__()
        E = c.n_embed
        s.c_fc = nn.Linear(E, 4 * E, c.bias)
        s.gelu = nn.GELU()
        s.c_proj = nn.Linear(4 * E, E, c.bias)
        s.dropout = nn.Dropout(c.dropout)

    def forward(s, x):
        return s.dropout(s.c_proj(s.gelu(s.c_fc(x))))
class TransLayer(nn.Module):
    def __init__(s, c: GPTConfig):
        super().__init__()
        s.ln_1 = nn.LayerNorm(c.n_embed, bias=c.bias)
        s.attn = Attention(c)
        s.ln_2 = nn.LayerNorm(c.n_embed, bias=c.bias)
        s.mlp = MLP(c)

    def forward(s, x):
        x = x + s.attn(s.ln_1(x))
        return x + s.mlp(s.ln_2(x))
class GPT(nn.Module):
    def __init__(s, c: GPTConfig):
        super().__init__()
        s.conf = c
        E = c.n_embed
        s.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(c.vocab_size, E),
                wpe=nn.Embedding(c.block_size, E),
                drop=nn.Dropout(c.dropout),
                h=nn.ModuleList([TransLayer(c) for _ in range(c.n_layer)]),
                ln_f=nn.LayerNorm(E, bias=c.bias),
            )
        )
        s.lm_head = nn.Linear(E, c.vocab_size, bias=False)
        s.transformer.wte.weight = s.lm_head.weight  # weight-tying
        s.apply(s._init_weights)
        # scaled init for the residual projections, per GPT-2
        for k, p in s.named_parameters():
            if k.endswith("c_proj.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * c.n_layer))
        print(f"n_params: {s.n_params() / 1e6:.2f}M")

    def n_params(s):
        return sum(p.numel() for p in s.parameters())

    def _init_weights(s, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
    def forward(s, x: tc.Tensor, y0: tc.Tensor = None):
        B, T = x.shape
        c = s.conf
        assert T <= c.block_size
        tok = s.transformer.wte(x)
        pos = s.transformer.wpe(tc.arange(0, T, dtype=tc.long, device=x.device))
        x = s.transformer.drop(tok + pos)
        for layer in s.transformer.h:
            x = layer(x)
        x = s.transformer.ln_f(x)
        if y0 is not None:
            # training: logits over all positions plus the LM loss
            logits: tc.Tensor = s.lm_head(x)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), y0.view(-1), ignore_index=-1
            )
        else:
            # inference: only the last position is needed
            logits = s.lm_head(x[:, [-1], :])
            loss = None
        return logits, loss
    def crop_block_size(s, n):
        assert n <= s.conf.block_size
        s.conf.block_size = n
        wpe = s.transformer.wpe
        wpe.weight = nn.Parameter(wpe.weight[:n])
        for layer in s.transformer.h:
            if hasattr(layer.attn, "bias"):
                layer.attn.bias = layer.attn.bias[:, :, :n, :n]
    @classmethod
    @tc.no_grad()
    def from_pretrained(s, type, drop=0.0):
        assert type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"}
        from transformers import GPT2LMHeadModel

        args = {
            "gpt2": dict(n_layer=12, n_head=12, n_embed=768),  # 124M params
            "gpt2-medium": dict(n_layer=24, n_head=16, n_embed=1024),  # 350M params
            "gpt2-large": dict(n_layer=36, n_head=20, n_embed=1280),  # 774M params
            "gpt2-xl": dict(n_layer=48, n_head=25, n_embed=1600),  # 1558M params
        }[type]
        args.update(dict(vocab_size=50257, block_size=1024, bias=True, dropout=drop))
        m = GPT(GPTConfig(**args))
        sd: TS_DICT = m.state_dict()
        ignore = (".attn.bias", ".attn.masked_bias")
        keys = [k for k in sd if not k.endswith(ignore)]
        m2 = GPT2LMHeadModel.from_pretrained(type)
        sd2: TS_DICT = m2.state_dict()
        keys2 = [k for k in sd2 if not k.endswith(ignore)]
        assert len(keys) == len(keys2)
        # HF stores these as Conv1D modules, so their weights must be transposed
        trans = [
            "attn.c_attn.weight",
            "attn.c_proj.weight",
            "mlp.c_fc.weight",
            "mlp.c_proj.weight",
        ]
        for k in keys2:
            x = sd2[k].t() if any(k.endswith(w) for w in trans) else sd2[k]
            sd[k].copy_(x)
        return m
    @tc.no_grad()
    def generate(s, x, num, temp=1.0):
        B = s.conf.block_size
        for _ in range(num):
            # crop the running context to the model's block size
            x = x if x.size(1) <= B else x[:, -B:]
            logits: tc.Tensor = s(x)[0]
            probs = F.softmax(logits[:, -1, :] / temp, dim=-1)
            x2 = tc.multinomial(probs, num_samples=1)
            x = tc.cat((x, x2), dim=1)
        return x
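For reference, a minimal usage sketch of the model above. It assumes `transformers` (already used inside `from_pretrained`) and `tiktoken` are installed; the prompt and sampling settings are arbitrary examples, not part of the repo.

import tiktoken  # assumed dependency; any GPT-2 BPE tokenizer would do
import torch as tc

enc = tiktoken.get_encoding("gpt2")
m = GPT.from_pretrained("gpt2")  # downloads the 124M checkpoint via transformers
m.eval()
x = tc.tensor([enc.encode("Hello, my name is")], dtype=tc.long)  # (1, T) prompt
y = m.generate(x, num=20, temp=0.8)  # sample 20 new tokens autoregressively
print(enc.decode(y[0].tolist()))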
trainer (not fully tested -- I don't have a GPU)
import os
from contextlib import nullcontext
import numpy as np
import torch as tc
import torch.nn as nn
import wandb
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
class LLMTrainer:
    # ======== setup ====================
    backend = "nccl"
    grad_acc_steps = 40
    w_decay = 0.1
    lr = 6e-4
    betas = [0.9, 0.95]
    fused = True
    compile = True
    save_path = "llm_cp"
    log_id = None
    # ======== get_batch =================
    data_path = "llm_data"
    batch_size = 12
    block_size = 1024
    # ======= get_cos_lr ===============
    i_warmup = 2000
    # ========== train =================
    i_eval = 2000
    n_eval = 200
    grad_clip = 1.0
    i_max = 600000
    def setup(s, m: nn.Module):
        cuda_ok = tc.cuda.is_available()
        bf16_ok = cuda_ok and tc.cuda.is_bf16_supported()  # guard: bf16 check needs CUDA
        device = "cuda" if cuda_ok else "cpu"
        dtype = "bfloat16" if cuda_ok and bf16_ok else "float16"
        rank = int(os.environ.get("RANK", -1))
        s.is_master = rank in [-1, 0]
        if rank != -1:  # DDP run launched with torchrun: one process per GPU
            init_process_group(s.backend)
            local_rank = int(os.environ["LOCAL_RANK"])
            device = f"cuda:{local_rank}"
            tc.cuda.set_device(device)
            n_proc = int(os.environ["WORLD_SIZE"])
            assert s.grad_acc_steps % n_proc == 0
            s.grad_acc_steps //= n_proc
        tc.manual_seed(1338 + rank)
        tc.backends.cuda.matmul.allow_tf32 = True
        tc.backends.cudnn.allow_tf32 = True
        s.ctx = (
            nullcontext()
            if device == "cpu"
            else tc.amp.autocast("cuda", getattr(tc, dtype))
        )
        s.scaler = tc.amp.grad_scaler.GradScaler(device, enabled=dtype == "float16")
        # ============================================
        # move the model first so the optimizer (and any resumed optimizer state)
        # is built on the right device
        m.to(device)
        g1 = [p for p in m.parameters() if p.dim() >= 2]  # weight-decayed
        g2 = [p for p in m.parameters() if p.dim() < 2]  # biases / norms: no decay
        params = [
            {"params": g1, "weight_decay": s.w_decay},
            {"params": g2, "weight_decay": 0.0},
        ]
        s.opt = tc.optim.AdamW(params, lr=s.lr, betas=s.betas, fused=s.fused)
        if os.path.exists(s.save_path):
            r = tc.load(s.save_path, map_location=device)
            m.load_state_dict(r["model"])
            s.opt.load_state_dict(r["opt"])
            s.iter, s.best_loss = r["iter"], r["best_loss"]
        else:
            s.iter, s.best_loss = 0, np.inf
        # ============================================
        if rank != -1:
            m = DDP(m, device_ids=[local_rank])
        if s.compile:
            print("compiling")
            m.compile()
        if s.log_id and s.is_master:  # only the master process logs
            wandb.init(project="LLMTrainer", name=s.log_id)
        s.device = device
        s.model = m
    def get_batch(s, split="train"):
        B, T = s.batch_size, s.block_size
        # keep the memmap lazy; cast each sampled slice instead of the whole file
        data = np.memmap(s.data_path, np.uint16, mode="r")
        a = int(len(data) * 0.9)
        data = data[:a] if split == "train" else data[a:]
        idx = tc.randint(len(data) - T, (B,))
        x = tc.stack([tc.from_numpy(data[i : i + T].astype(np.int64)) for i in idx])
        y = tc.stack([tc.from_numpy(data[i + 1 : i + 1 + T].astype(np.int64)) for i in idx])
        if s.device != "cpu":
            x = x.pin_memory().to(s.device, non_blocking=True)
            y = y.pin_memory().to(s.device, non_blocking=True)
        return x, y
    @tc.no_grad()
    def get_losses(s):
        m = s.model
        m.eval()
        res = {}
        for split in ["train", "val"]:
            losses = tc.zeros(s.n_eval)
            for k in range(s.n_eval):
                x, y = s.get_batch(split)
                with s.ctx:
                    loss: tc.Tensor = m(x, y)[1]
                losses[k] = loss.item()
            res[split] = losses.mean()
        m.train()
        return res
    def get_cos_lr(s, i):
        # linear warmup, then cosine decay from lr down to 0.1 * lr
        min_lr = s.lr * 0.1
        if i < s.i_warmup:
            return s.lr * (i + 1) / (s.i_warmup + 1)
        if i > s.i_max:
            return min_lr
        r = (i - s.i_warmup) / (s.i_max - s.i_warmup)
        assert 0 <= r <= 1
        c = 0.5 * (1.0 + np.cos(np.pi * r))
        return min_lr + c * (s.lr - min_lr)
    def train(s):
        x, y = s.get_batch()
        m = s.model
        is_ddp = isinstance(m, DDP)
        raw_m = m.module if is_ddp else m  # unwrap DDP so checkpoints load cleanly
        while s.iter < s.i_max:
            lr = s.get_cos_lr(s.iter)
            for p in s.opt.param_groups:
                p["lr"] = lr
            if s.iter % s.i_eval == 0 and s.is_master:
                losses = s.get_losses()
                print(f"{s.iter} {losses}")
                if s.log_id:
                    wandb.log(
                        {
                            "iter": s.iter,
                            "train/loss": losses["train"],
                            "val/loss": losses["val"],
                            "lr": lr,
                        }
                    )
                if losses["val"] < s.best_loss:
                    s.best_loss = losses["val"]
                    if s.iter:
                        obj = {
                            "model": raw_m.state_dict(),
                            "opt": s.opt.state_dict(),
                            "iter": s.iter,
                            "best_loss": s.best_loss,
                        }
                        tc.save(obj, s.save_path)
            for j in range(s.grad_acc_steps):
                if is_ddp:
                    # only sync gradients across ranks on the last micro-step
                    m.require_backward_grad_sync = j == s.grad_acc_steps - 1
                with s.ctx:
                    logits, loss = m(x, y)
                    loss /= s.grad_acc_steps
                x, y = s.get_batch()  # prefetch the next micro-batch
                s.scaler.scale(loss).backward()
            if s.grad_clip:
                s.scaler.unscale_(s.opt)
                nn.utils.clip_grad_norm_(m.parameters(), s.grad_clip)
            s.scaler.step(s.opt)
            s.scaler.update()
            s.opt.zero_grad()
            s.iter += 1
        if is_ddp:
            destroy_process_group()
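And a minimal sketch of how the two listings might be wired together. The module names, token file, and overrides below are assumptions for illustration only; under DDP the script would be launched with torchrun.

from model import GPT, GPTConfig  # hypothetical module names for the two files above
from trainer import LLMTrainer

if __name__ == "__main__":
    t = LLMTrainer()
    t.data_path = "llm_data"  # flat np.uint16 token ids; get_batch splits 90/10 train/val
    t.compile = False  # skip torch.compile on machines without a recent GPU
    m = GPT(GPTConfig())  # defaults: 12 layers, 12 heads, 768-dim, 1024-token context
    t.setup(m)  # optimizer, AMP context, checkpoint resume, optional DDP wrap
    t.train()  # single process: python train.py; multi-GPU: torchrun --nproc_per_node=N train.py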