Description
I'm not sure Issues is the best place to post this, but I just wanted to see if anyone else has been trying this idea:
There was a paper published recently proposing a new attention head architecture (the Differential Transformer), and I wanted to see if I could replicate its results, which the paper reports as very promising. It didn't seem too hard given what I knew from messing around with this repo. The authors provide three versions of the code here, and to keep things simple I tried to use this implementation here. I added rotary positional encoding separately first and tested that it worked well, then added the differential mechanism. My code looks like this:
class CausalSelfAttention(nn.Module):

    def __init__(self, config, depth):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # separate q, k, v projections (RotaryEmbedding, apply_rotary_pos_emb,
        # RMSNorm and lambda_init_fn are defined elsewhere in my model.py)
        self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.k_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.v_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head // 2  # div by 2 because each head is larger, so we only have half as many
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.head_dim = self.n_embd // self.n_head // 2  # div by 2 because of the doubled key and query
        self.rotary_emb = RotaryEmbedding(dim=self.head_dim, max_position_embeddings=config.block_size)  # Added line
        self.lambda_init = lambda_init_fn(depth)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.subln = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=False)
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        k = k.view(B, T, self.n_head * 2, self.head_dim)  # (B, T, 2*nh, hs)
        q = q.view(B, T, self.n_head * 2, self.head_dim)  # (B, T, 2*nh, hs)
        v = v.view(B, T, self.n_head, 2, self.head_dim)   # (B, T, nh, 2, hs)
        # apply rotary embeddings to q and k
        cos, sin = self.rotary_emb(q, seq_len=T)
        q = apply_rotary_pos_emb(q, cos, sin)
        k = apply_rotary_pos_emb(k, cos, sin)
        # split q and k into the two "differential" halves
        q = q.reshape(B, T, self.n_head, 2, self.head_dim)
        k = k.reshape(B, T, self.n_head, 2, self.head_dim)
        q1, q2 = q[:, :, :, 0], q[:, :, :, 1]
        k1, k2 = k[:, :, :, 0], k[:, :, :, 1]
        v1, v2 = v[:, :, :, 0], v[:, :, :, 1]
        # first attention map applied to both value halves
        attn11 = F.scaled_dot_product_attention(q1, k1, v1, attn_mask=None, is_causal=True)
        attn12 = F.scaled_dot_product_attention(q1, k1, v2, attn_mask=None, is_causal=True)
        attn1 = torch.cat([attn11, attn12], dim=-1)
        # second attention map applied to both value halves
        attn21 = F.scaled_dot_product_attention(q2, k2, v1, attn_mask=None, is_causal=True)
        attn22 = F.scaled_dot_product_attention(q2, k2, v2, attn_mask=None, is_causal=True)
        attn2 = torch.cat([attn21, attn22], dim=-1)
        # reparameterized lambda, as in the paper
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        # differential attention: subtract the second map, normalize headwise, rescale
        attn = attn1 - lambda_full * attn2
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)
        attn = attn.reshape(B, T, C)
        # output projection
        y = self.resid_dropout(self.c_proj(attn))
        return y
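For reference, this is the plain (non-fused) form of the differential attention I'm trying to reproduce, as I understand it from the paper, along with the depth schedule that lambda_init_fn implements. The single-head helper below is just my own sketch of the math; its name and the mask argument are mine, not from the authors' repo:

import math
import torch

def lambda_init_fn(depth):
    # depth-dependent initial lambda from the paper: 0.8 - 0.6 * exp(-0.3 * depth)
    return 0.8 - 0.6 * math.exp(-0.3 * depth)

def diff_attn_single_head(q1, q2, k1, k2, v, lambda_full, causal_mask):
    # q1, q2, k1, k2: (B, T, d); v: (B, T, 2*d); causal_mask: (T, T) with -inf above the diagonal
    d = q1.size(-1)
    a1 = torch.softmax(q1 @ k1.transpose(-2, -1) / math.sqrt(d) + causal_mask, dim=-1)
    a2 = torch.softmax(q2 @ k2.transpose(-2, -1) / math.sqrt(d) + causal_mask, dim=-1)
    # differential attention: the difference of the two softmax maps, applied to v
    return (a1 - lambda_full * a2) @ v

The four-SDPA version in my class is meant to be equivalent: attn1 concatenates A1@v1 and A1@v2, attn2 concatenates A2@v1 and A2@v2, so attn1 - lambda_full * attn2 should equal (A1 - lambda_full * A2) applied to the full v.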
When I try to train this model it understandably runs at a lower iterations/sec, but if we look at the loss per iteration it seems to be getting stuck. (In each iteration I have kept the total batch size the same as in the gpt2-124M-RoPE run; see the sketch below.)
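Roughly, what I mean by keeping the total batch size fixed (nanoGPT-style numbers for illustration, not necessarily my exact config):

batch_size = 12                   # micro-batch size
block_size = 1024                 # sequence length
gradient_accumulation_steps = 40  # micro-batches accumulated per optimizer step
# tokens processed per iteration, held the same for both runs
tokens_per_iter = gradient_accumulation_steps * batch_size * block_size  # ~0.5M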
Any ideas on what I've gotten wrong? I'm no ML expert.
@karpathy, on the off chance that you see this: have you read the Diff Transformer paper, and if so, what do you think of it?