How best to implement a differential transformer? #567

@Wilsontomass

I'm not sure Issues is the best place to post this, but I just wanted to see if anyone else has been trying this idea:

There was a paper published recently that proposed a new attention head architecture, and I wanted to see if I could replicate the results (according to the paper they are very promising). It didn't seem too hard given what I knew from messing around with this repo. The authors provide three versions of the code here, and to keep things simple I tried to use this implementation here. I added rotary positional encoding separately first and tested it; that worked well. Then I added the differential mechanism. My code looks like this (after the class I've also written out the formulation I'm trying to match, as a sanity check):

class CausalSelfAttention(nn.Module):
    def __init__(self, config, depth):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.k_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.v_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head // 2  # div by 2 because each head is larger, so we only have half as many
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        
        self.head_dim = self.n_embd // self.n_head // 2 # div by 2 because double key and query
        self.rotary_emb = RotaryEmbedding(dim=self.head_dim, max_position_embeddings=config.block_size)  # Added line
        
        self.lambda_init = lambda_init_fn(depth)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.subln = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=False)
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        k = k.view(B, T, self.n_head*2, self.head_dim)  # (B, T, nh, hs)
        q = q.view(B, T, self.n_head*2, self.head_dim)  # (B, T, nh, hs)
        v = v.view(B, T, self.n_head, 2, self.head_dim)  # (B, T, nh, hs)
        # Apply rotary embeddings to q and k
        cos, sin = self.rotary_emb(q, seq_len=T)
        q = apply_rotary_pos_emb(q, cos, sin)
        k = apply_rotary_pos_emb(k, cos, sin)
        q = q.reshape(B, T, self.n_head, 2, self.head_dim)
        k = k.reshape(B, T, self.n_head, 2, self.head_dim)
        q1, q2 = q[:, :, :, 0], q[:, :, :, 1]
        k1, k2 = k[:, :, :, 0], k[:, :, :, 1]
        v1, v2 = v[:, :, :, 0], v[:, :, :, 1]
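        # flash-style trick from the authors' repo (as I understand it): each softmax map is
        # applied to v1 and v2 separately and the halves concatenated, which should be equivalent
        # to applying it to the full 2*head_dim value, so attn1 and attn2 can be subtracted after
        # the softmaxes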
        attn11 = F.scaled_dot_product_attention(q1, k1, v1, attn_mask=None, is_causal=True)
        attn12 = F.scaled_dot_product_attention(q1, k1, v2, attn_mask=None, is_causal=True)
        attn1 = torch.cat([attn11, attn12], dim=-1)
        attn21 = F.scaled_dot_product_attention(q2, k2, v1, attn_mask=None, is_causal=True)
        attn22 = F.scaled_dot_product_attention(q2, k2, v2, attn_mask=None, is_causal=True)
        attn2 = torch.cat([attn21, attn22], dim=-1)
        
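        # lambda reparameterization from the paper:
        # lambda = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init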
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        attn = attn1 - lambda_full * attn2
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)
        attn = attn.reshape(B, T, C)
        # output projection
        y = self.resid_dropout(self.c_proj(attn))
        return y
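
For reference, here is the formulation from the paper that I'm trying to match, written out as a quick shape check with random tensors (my own sketch, not the authors' code; lambda_full is just a placeholder scalar here):

import math
import torch

B, T, nh, hd = 2, 8, 6, 32                       # batch, time, heads, head_dim
q1, q2, k1, k2 = (torch.randn(B, nh, T, hd) for _ in range(4))
v = torch.randn(B, nh, T, 2 * hd)                # full per-head value, dim 2*head_dim
lambda_full = torch.tensor(0.8)                  # placeholder for the learned lambda

mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
a1 = (q1 @ k1.transpose(-2, -1)) / math.sqrt(hd)
a2 = (q2 @ k2.transpose(-2, -1)) / math.sqrt(hd)
a1 = a1.masked_fill(mask, float("-inf")).softmax(dim=-1)
a2 = a2.masked_fill(mask, float("-inf")).softmax(dim=-1)
out = (a1 - lambda_full * a2) @ v                # (B, nh, T, 2*head_dim)
print(out.shape)

As far as I can tell, the concatenation trick in my forward pass should be equivalent to this, since each softmax map is applied linearly to the values.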

When I try to train this model it understandably runs at fewer iterations per second, but if we look at the loss per iteration it seems to be getting stuck. (For each iteration I have kept the total batch size the same as in the gpt2-124M-RoPE run.)

[Plot: training loss per iteration vs. the gpt2-124M-RoPE run]

Any ideas on what I've gotten wrong? I'm no ML expert.
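
One thing I'm now wondering about is the tensor layout: as far as I know, F.scaled_dot_product_attention attends over the last two dimensions, i.e. it expects something like (B, n_head, T, head_dim), whereas I'm passing (B, T, n_head, head_dim), which would mean is_causal masks across heads rather than across time. A toy example of the layout I mean (my own sketch, not tested inside the full model):

import torch
import torch.nn.functional as F

B, T, nh, hs = 2, 8, 6, 32
q1, k1, v1 = (torch.randn(B, T, nh, hs) for _ in range(3))

# SDPA treats the last two dims as (sequence, feature), so heads go before time:
out = F.scaled_dot_product_attention(
    q1.transpose(1, 2),    # (B, nh, T, hs)
    k1.transpose(1, 2),
    v1.transpose(1, 2),
    is_causal=True,
).transpose(1, 2)          # back to (B, T, nh, hs)
print(out.shape)           # torch.Size([2, 8, 6, 32])

If that's right, the tensors in my forward pass would need a transpose(1, 2) before the SDPA calls and another one on the way back out.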

@karpathy on the off chance that you see this: have you read the diff transformer paper, and if so, what do you think of it?
