【AIGC】大模型面试高频考点-位置编码篇
- (一)手撕 绝对位置编码 算法
- (二)手撕 可学习位置编码 算法
- (三)手撕 相对位置编码 算法
- (四)手撕 RoPE 算法(旋转位置编码)
(一)手撕 绝对位置编码 算法
class SinPositionEncoding(nn.Module):
    """Fixed sinusoidal (absolute) position encoding.

    ``forward()`` returns a ``(max_sequence_length, d_model)`` float32 tensor with

        pe[pos, 2i]     = sin(pos / base ** (2i / d_model))
        pe[pos, 2i + 1] = cos(pos / base ** (2i / d_model))

    Args:
        max_sequence_length: number of positions (rows) to encode.
        d_model: encoding width (columns). Odd widths are supported: the last
            column then holds a sine with no matching cosine.
        base: base of the geometric progression of wavelengths.
    """

    def __init__(self, max_sequence_length, d_model, base=10000):
        super().__init__()
        self.max_sequence_length = max_sequence_length
        self.d_model = d_model
        self.base = base

    def forward(self):
        pe = torch.zeros(self.max_sequence_length, self.d_model, dtype=torch.float)
        # One frequency per even column: ceil(d_model / 2), so odd widths work
        # (for even d_model this equals d_model // 2).
        exp_1 = torch.arange((self.d_model + 1) // 2, dtype=torch.float)
        # i / (d_model / 2) == 2i / d_model, the exponent of the formula above.
        exp_value = exp_1 / (self.d_model / 2)
        alpha = 1 / (self.base ** exp_value)
        # Outer product: out[pos, i] = pos * alpha[i].
        out = torch.arange(self.max_sequence_length, dtype=torch.float)[:, None] @ alpha[None, :]
        pe[:, 0::2] = torch.sin(out)
        # There are d_model // 2 odd columns, one fewer than frequencies when d_model is odd.
        pe[:, 1::2] = torch.cos(out[:, : self.d_model // 2])
        return pe


SinPositionEncoding(d_model=4, max_sequence_length=10, base=10000).forward()
(二)手撕 可学习位置编码 算法
class TrainablePositionEncoding(nn.Module):
    """Learnable absolute position table of shape ``(max_sequence_length, d_model)``.

    The ``nn.Embedding`` is created once in ``__init__``, so it is registered as a
    submodule. Its weight therefore appears in ``parameters()`` / ``state_dict()``
    and can be updated by an optimizer built from them.

    ``forward()`` returns that same embedding module on every call.
    """

    def __init__(self, max_sequence_length, d_model):
        super().__init__()
        self.max_sequence_length = max_sequence_length
        self.d_model = d_model
        self.pe = nn.Embedding(self.max_sequence_length, self.d_model)
        # Zero initialisation (as before); `constant_` replaces the deprecated `constant`.
        nn.init.constant_(self.pe.weight, 0.)

    def forward(self):
        return self.pe
(三)手撕 相对位置编码 算法
class RelativePosition(nn.Module):
    """Learned embeddings for clipped relative distances.

    Args:
        num_units: width of each relative-distance embedding.
        max_relative_position: distances are clipped to
            ``[-max_relative_position, max_relative_position]``, giving
            ``2 * max_relative_position + 1`` distinct embeddings.
    """

    def __init__(self, num_units, max_relative_position):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        self.embeddings_table = nn.Parameter(torch.empty(max_relative_position * 2 + 1, num_units))
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        """Return a ``(length_q, length_k, num_units)`` tensor.

        Entry ``[i, j]`` is the embedding of the clipped distance ``j - i``.
        Indices are built on the table's own device, so this runs on CPU or GPU.
        """
        device = self.embeddings_table.device
        range_vec_q = torch.arange(length_q, device=device)
        range_vec_k = torch.arange(length_k, device=device)
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
        # Shift into [0, 2 * max_relative_position] so it can index the table.
        final_mat = distance_mat_clipped + self.max_relative_position
        return self.embeddings_table[final_mat]


class RelativeMultiHeadAttention(nn.Module):
    """Multi-head attention with relative position representations.

    Scores are ``(q k^T + q a_K^T) / sqrt(head_dim)`` and outputs are
    ``softmax(scores) v + softmax(scores) a_V``. Here ``a_K`` / ``a_V`` come from two
    ``RelativePosition`` tables of width ``head_dim``.

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        batch_size: fallback batch size, used only when inputs are not 3-D
            (e.g. flattened ``(batch * seq, d_model)``). For 3-D inputs the
            batch size is taken from ``query.size(0)``.
        max_relative_position: clipping distance for the relative embeddings.

    ``forward`` returns a 2-D tensor of shape ``(batch * len_q, d_model)``.
    """

    def __init__(self, d_model, n_heads, dropout=0.1, batch_size=6, max_relative_position=16):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.batch_size = batch_size
        self.head_dim = d_model // n_heads
        # Projections for query, key, value (first three) and the output (last).
        self.linears = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(4)])
        self.dropout = nn.Dropout(p=dropout)
        self.relative_position_k = RelativePosition(self.head_dim, max_relative_position)
        self.relative_position_v = RelativePosition(self.head_dim, max_relative_position)
        # Plain float so it follows the scores' device/dtype.
        self.scale = self.head_dim ** 0.5

    def forward(self, query, key, value):
        batch_size = query.size(0) if query.dim() == 3 else self.batch_size
        query, key, value = [
            proj(x).view(batch_size, -1, self.d_model)
            for proj, x in zip(self.linears, (query, key, value))
        ]
        len_q = query.shape[1]
        len_k = key.shape[1]
        len_v = value.shape[1]

        # Content scores: (B, H, len_q, head_dim) @ (B, H, head_dim, len_k).
        r_q1 = query.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        r_k1 = key.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2))

        # Relative-key scores, batched over query positions:
        # (len_q, B*H, head_dim) @ (len_q, head_dim, len_k) -> (len_q, B*H, len_k).
        r_q2 = query.permute(1, 0, 2).contiguous().view(len_q, batch_size * self.n_heads, self.head_dim)
        r_k2 = self.relative_position_k(len_q, len_k)
        attn2 = torch.matmul(r_q2, r_k2.transpose(1, 2)).transpose(0, 1)
        attn2 = attn2.contiguous().view(batch_size, self.n_heads, len_q, len_k)

        attn = (attn1 + attn2) / self.scale
        attn = self.dropout(torch.softmax(attn, dim=-1))

        # Content values: (B, H, len_q, len_k) @ (B, H, len_v, head_dim).
        r_v1 = value.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        weight1 = torch.matmul(attn, r_v1)

        # Relative values, batched over query positions (requires len_k == len_v).
        r_v2 = self.relative_position_v(len_q, len_v)
        weight2 = attn.permute(2, 0, 1, 3).contiguous().view(len_q, batch_size * self.n_heads, len_k)
        weight2 = torch.matmul(weight2, r_v2)
        weight2 = weight2.transpose(0, 1).contiguous().view(batch_size, self.n_heads, len_q, self.head_dim)

        x = weight1 + weight2
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size * len_q, self.d_model)
        return self.linears[-1](x)
(四)手撕 RoPE 算法(旋转位置编码)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):
    """Build the interleaved sin/cos table used by ``RoPE``.

    Returns a float32 tensor of shape ``(batch_size, nums_head, max_len, output_dim)``
    on ``device``. Along the last axis the layout is
    ``[sin(p*t0), cos(p*t0), sin(p*t1), cos(p*t1), ...]``, where
    ``t_i = 10000 ** (-2i / output_dim)`` and ``p`` is the position.

    The same table is repeated for every batch element and head.
    ``output_dim`` must be even for the final reshape to succeed.
    """
    freq_idx = torch.arange(0, output_dim // 2, dtype=torch.float)
    inv_freq = torch.pow(10000, -2 * freq_idx / output_dim)
    pos = torch.arange(0, max_len, dtype=torch.float)[:, None]
    angles = pos * inv_freq                                    # (max_len, output_dim // 2)
    pairs = torch.stack([angles.sin(), angles.cos()], dim=-1)  # (max_len, output_dim // 2, 2)
    tiled = pairs.repeat(batch_size, nums_head, 1, 1, 1)
    return tiled.reshape(batch_size, nums_head, max_len, output_dim).to(device)
def _rotate_every_two(x):
    """Map interleaved pairs ``(x0, x1, x2, x3, ...)`` to ``(-x1, x0, -x3, x2, ...)``."""
    return torch.stack([-x[..., 1::2], x[..., ::2]], dim=-1).reshape(x.shape)


def RoPE(q, k):
    """Apply rotary position embeddings to ``q`` and ``k``.

    Assumes ``q`` and ``k`` are ``(batch, heads, seq_len, head_dim)`` with the same
    ``seq_len`` (positions ``0 .. seq_len - 1``) and an even ``head_dim``.

    Each pair ``(x[2i], x[2i+1])`` at position ``m`` is rotated by the angle
    ``m * 10000 ** (-2i / head_dim)``:

        x'[2i]   = x[2i] * cos - x[2i+1] * sin
        x'[2i+1] = x[2i+1] * cos + x[2i] * sin

    Returns ``(q, k)`` rotated, each in its own input dtype.
    """
    max_len = q.shape[2]
    output_dim = q.shape[-1]
    # The table is identical for every batch element and head, so build it
    # once as (1, 1, max_len, output_dim) and let broadcasting expand it.
    pos_emb = sinusoidal_position_embedding(1, 1, max_len, output_dim, q.device)
    # pos_emb interleaves (sin, cos) per frequency; copy each onto both lanes of its pair.
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)
    # The table is float32; cast back so half-precision inputs are not
    # silently promoted to float32 (which breaks a later matmul with fp16 v).
    q_rot = (q * cos_pos + _rotate_every_two(q) * sin_pos).to(q.dtype)
    k_rot = (k * cos_pos + _rotate_every_two(k) * sin_pos).to(k.dtype)
    return q_rot, k_rot
def attention(q, k, v, mask=None, dropout=None, use_RoPE=True):
    """Scaled dot-product attention, optionally applying RoPE to ``q`` and ``k`` first.

    Args:
        q: query tensor ``(..., len_q, d_k)``.
        k: key tensor ``(..., len_k, d_k)``.
        v: value tensor ``(..., len_k, d_v)``.
        mask: optional tensor broadcastable to ``(..., len_q, len_k)``; positions
            where ``mask == 0`` are excluded from the softmax.
        dropout: optional callable (e.g. an ``nn.Dropout``) applied to the weights.
        use_RoPE: if true, rotate ``q`` and ``k`` with ``RoPE`` before scoring.

    Returns:
        ``(output, weights)`` of shapes ``(..., len_q, d_v)`` and ``(..., len_q, len_k)``.
    """
    if use_RoPE:
        q, k = RoPE(q, k)
    d_k = k.size(-1)
    att_logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Most negative finite value of the logits' dtype; a literal -1e9
        # cannot be represented in float16 and makes masked_fill raise.
        att_logits = att_logits.masked_fill(mask == 0, torch.finfo(att_logits.dtype).min)
    att_scores = F.softmax(att_logits, dim=-1)
    if dropout is not None:
        att_scores = dropout(att_scores)
    return torch.matmul(att_scores, v), att_scores


if __name__ == '__main__':
    q = torch.randn((8, 12, 10, 32))
    k = torch.randn((8, 12, 10, 32))
    v = torch.randn((8, 12, 10, 32))
    res, att_scores = attention(q, k, v, mask=None, dropout=None, use_RoPE=True)
    print(res.shape, att_scores.shape)