| import numpy as np
|
|
|
|
|
# Model / simulation hyperparameters.
batch_size = 4
d_model = 512  # width of every token representation
num_heads = 8
# Per-head width. Derived rather than hard-coded so the head split in
# multi_head_attention (reshape into num_heads x d_k) always matches d_model.
d_k = d_model // num_heads
d_ff = 2048  # hidden width of the position-wise feed-forward network
vocab_size = 10000
enc_seq_len = 10

if d_k * num_heads != d_model:
    raise ValueError(
        f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
    )

# Random encoder input embeddings, shape (batch_size, enc_seq_len, d_model).
input_data = np.random.randn(batch_size, enc_seq_len, d_model) * 0.1
|
|
|
|
|
|
|
def init_weights(shape):
    """Create a parameter array of the given shape.

    A 1-D ``shape`` is treated as a bias vector and filled with zeros. Any
    other shape is drawn from a standard normal distribution and scaled by
    ``sqrt(1 / shape[0])`` (fan-in scaling; the original comment calls this a
    simplified He/Xavier initialization).
    """
    if len(shape) != 1:
        # Draw the samples before reading shape[0] so the RNG is consumed in
        # the same order as a single combined expression would.
        samples = np.random.randn(*shape)
        return samples * np.sqrt(1.0 / shape[0])
    return np.zeros(shape)
|
|
|
|
|
|
|
def layer_normalization(x, gamma, beta, epsilon=1e-5):
    """Normalize ``x`` over its last axis, then apply a learned affine map.

    Each vector along the final axis is shifted to zero mean and scaled to
    unit variance. ``epsilon`` is added to the variance for numerical
    stability. The result is multiplied by ``gamma`` and shifted by ``beta``.
    """
    mu = np.mean(x, axis=-1, keepdims=True)
    centered = x - mu
    var = np.mean(centered ** 2, axis=-1, keepdims=True)
    return gamma * (centered / np.sqrt(var + epsilon)) + beta
|
|
|
def scaled_dot_product_attention(Q, K, V, mask=None):
    """Scaled dot-product attention over batched inputs.

    Computes ``softmax(Q @ K^T / sqrt(d) + mask) @ V``, where ``d`` is the
    size of the last axis of ``Q`` (the query/key width). Any number of
    leading batch/head axes is supported; only the last two axes take part
    in the matrix products.

    Args:
        Q: queries, shape ``(..., S_q, d)``.
        K: keys, shape ``(..., S_k, d)``.
        V: values, shape ``(..., S_k, d_v)``.
        mask: optional additive mask broadcastable to ``(..., S_q, S_k)``.
            Large negative entries (e.g. ``-1e9``) suppress the matching key
            positions.

    Returns:
        ``(output, attention_weights)`` with shapes ``(..., S_q, d_v)`` and
        ``(..., S_q, S_k)``.
    """
    # Scale by the actual key width rather than a module-level constant, so
    # the function is correct for any head size.
    scale = np.sqrt(Q.shape[-1])
    # swapaxes transposes only the last two axes, for inputs of any rank >= 2.
    scores = np.matmul(Q, np.swapaxes(K, -1, -2)) / scale

    if mask is not None:
        scores = scores + mask

    # Subtract the row max before exponentiating, for numerical stability.
    exp_scores = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

    output = np.matmul(attention_weights, V)
    return output, attention_weights
|
|
|
def _split_heads(x, n_heads):
    """Reshape ``(B, S, P)`` into ``(B, n_heads, S, P // n_heads)``.

    Raises:
        ValueError: if ``P`` is not divisible by ``n_heads``.
    """
    batch, seq_len, width = x.shape
    if width % n_heads:
        raise ValueError(
            f"projected width {width} is not divisible by num_heads={n_heads}"
        )
    return x.reshape(batch, seq_len, n_heads, width // n_heads).transpose(0, 2, 1, 3)


def _merge_heads(x):
    """Inverse of ``_split_heads``: ``(B, H, S, d)`` -> ``(B, S, H * d)``."""
    batch, n_heads, seq_len, head_dim = x.shape
    return x.transpose(0, 2, 1, 3).reshape(batch, seq_len, n_heads * head_dim)


def multi_head_attention(Q, K, V, W_Q, W_K, W_V, W_O, mask=None):
    """Multi-head attention with independent query and key/value lengths.

    Projects the inputs, splits the projected features into ``num_heads``
    heads (module-level setting), runs scaled dot-product attention per head,
    concatenates the heads and applies the output projection ``W_O``.

    Args:
        Q: query inputs, shape ``(B, S_q, D)``.
        K: key inputs, shape ``(B, S_k, D)``.
        V: value inputs, shape ``(B, S_k, D)``.
        W_Q, W_K, W_V: projection matrices of shape ``(D, P)``. ``P`` must
            be divisible by ``num_heads``.
        W_O: output projection of shape ``(P, D_out)``.
        mask: optional additive mask broadcastable to
            ``(B, num_heads, S_q, S_k)``.

    Returns:
        Array of shape ``(B, S_q, D_out)``.

    Raises:
        ValueError: if a projected width is not divisible by ``num_heads``.
    """
    # Head widths come from the projected arrays, not the module-level
    # d_k/d_model. Queries and keys/values may have different sequence
    # lengths, and the projection width is not tied to d_model.
    Q_multi = _split_heads(np.matmul(Q, W_Q), num_heads)
    K_multi = _split_heads(np.matmul(K, W_K), num_heads)
    V_multi = _split_heads(np.matmul(V, W_V), num_heads)

    attended_output, _ = scaled_dot_product_attention(Q_multi, K_multi, V_multi, mask)

    # Concatenate the heads back to (B, S_q, num_heads * head_dim), then project.
    return np.matmul(_merge_heads(attended_output), W_O)
|
|
|
def feed_forward_network(x, W1, b1, W2, b2):
    """Position-wise feed-forward network: ``relu(x @ W1 + b1) @ W2 + b2``.

    The matrix products act on the last axis of ``x``, so every position is
    transformed independently.
    """
    activated = np.maximum(0, np.matmul(x, W1) + b1)  # ReLU
    return np.matmul(activated, W2) + b2
|
|
|
|
|
|
|
|
|
# ---- Encoder block parameters ----
# Self-attention projections (Q, K, V, O), each d_model x d_model, created in
# Q, K, V, O order.
W_Q_enc, W_K_enc, W_V_enc, W_O_enc = [
    init_weights((d_model, d_model)) for _ in range(4)
]
# Feed-forward weights.
W1_enc, W2_enc = init_weights((d_model, d_ff)), init_weights((d_ff, d_model))
# Feed-forward biases start at zero. init_weights only zero-fills 1-D shapes,
# so passing these (1, N) shapes to it produced unit-variance random biases.
b1_enc, b2_enc = np.zeros((1, d_ff)), np.zeros((1, d_model))
# LayerNorm scale/shift for the two Add & Norm sub-layers, shaped to
# broadcast over (batch, seq, d_model).
gamma_enc1, beta_enc1 = np.ones((1, 1, d_model)), np.zeros((1, 1, d_model))
gamma_enc2, beta_enc2 = np.ones((1, 1, d_model)), np.zeros((1, 1, d_model))
|
|
|
|
|
# ---- Decoder block parameters ----
# Masked self-attention projections (Q, K, V, O), each d_model x d_model.
W_Q_dec_self, W_K_dec_self, W_V_dec_self, W_O_dec_self = [
    init_weights((d_model, d_model)) for _ in range(4)
]
# Cross-attention projections (queries from the decoder, keys/values from the
# encoder output), each d_model x d_model.
W_Q_dec_cross, W_K_dec_cross, W_V_dec_cross, W_O_dec_cross = [
    init_weights((d_model, d_model)) for _ in range(4)
]
# Feed-forward weights.
W1_dec, W2_dec = init_weights((d_model, d_ff)), init_weights((d_ff, d_model))
# Feed-forward biases start at zero. init_weights only zero-fills 1-D shapes,
# so passing these (1, N) shapes to it produced unit-variance random biases.
b1_dec, b2_dec = np.zeros((1, d_ff)), np.zeros((1, d_model))
# LayerNorm scale/shift for the three Add & Norm sub-layers, shaped to
# broadcast over (batch, seq, d_model).
gamma_dec1, beta_dec1 = np.ones((1, 1, d_model)), np.zeros((1, 1, d_model))
gamma_dec2, beta_dec2 = np.ones((1, 1, d_model)), np.zeros((1, 1, d_model))
gamma_dec3, beta_dec3 = np.ones((1, 1, d_model)), np.zeros((1, 1, d_model))
|
|
|
|
|
|
|
|
|
def encoder_block(x):
    """Apply one post-LN Transformer encoder layer to ``x``.

    Runs self-attention, then a residual add and LayerNorm. Then runs the
    feed-forward network, then another residual add and LayerNorm. Uses the
    module-level ``*_enc`` parameters. The residual additions require each
    sub-layer output to have the same shape as its input.
    """
    def add_and_norm(residual, sublayer_out, gamma, beta):
        # Residual connection followed by layer normalization.
        return layer_normalization(sublayer_out + residual, gamma, beta)

    # Sub-layer 1: self-attention (queries, keys and values are all x).
    hidden = add_and_norm(
        x,
        multi_head_attention(x, x, x, W_Q_enc, W_K_enc, W_V_enc, W_O_enc),
        gamma_enc1,
        beta_enc1,
    )
    # Sub-layer 2: position-wise feed-forward network.
    return add_and_norm(
        hidden,
        feed_forward_network(hidden, W1_enc, b1_enc, W2_enc, b2_enc),
        gamma_enc2,
        beta_enc2,
    )
|
|
|
|
|
|
|
def create_look_ahead_mask(size):
    """Build an additive causal mask of shape ``(1, 1, size, size)``.

    Entry ``[0, 0, i, j]`` is ``-1e9`` when ``j > i`` (a future position) and
    zero otherwise. Adding it to attention scores blocks attention to later
    positions. The two leading singleton axes let it broadcast against
    ``(batch, heads, size, size)`` score arrays.
    """
    future = np.triu(np.ones((size, size)), k=1)  # 1 strictly above the diagonal
    return (-1e9 * future).reshape(1, 1, size, size)
|
|
|
def decoder_block(x, enc_output, look_ahead_mask):
    """Apply one post-LN Transformer decoder layer.

    Three sub-layers, each followed by a residual add and LayerNorm:

    1. Masked self-attention over ``x`` using ``look_ahead_mask``.
    2. Cross-attention, with queries from the decoder and keys/values from
       ``enc_output``.
    3. The feed-forward network.

    Uses the module-level ``*_dec`` parameters.
    """
    def add_and_norm(residual, sublayer_out, gamma, beta):
        # Residual connection followed by layer normalization.
        return layer_normalization(sublayer_out + residual, gamma, beta)

    # Sub-layer 1: masked (causal) self-attention.
    self_attended = add_and_norm(
        x,
        multi_head_attention(
            x, x, x,
            W_Q_dec_self, W_K_dec_self, W_V_dec_self, W_O_dec_self,
            mask=look_ahead_mask,
        ),
        gamma_dec1,
        beta_dec1,
    )
    # Sub-layer 2: cross-attention over the encoder output (no mask).
    cross_attended = add_and_norm(
        self_attended,
        multi_head_attention(
            self_attended, enc_output, enc_output,
            W_Q_dec_cross, W_K_dec_cross, W_V_dec_cross, W_O_dec_cross,
            mask=None,
        ),
        gamma_dec2,
        beta_dec2,
    )
    # Sub-layer 3: position-wise feed-forward network.
    return add_and_norm(
        cross_attended,
        feed_forward_network(cross_attended, W1_dec, b1_dec, W2_dec, b2_dec),
        gamma_dec3,
        beta_dec3,
    )
|
|
|
|
|
|
|
# Final vocabulary projection: maps d_model features to vocab_size logits.
W_linear = init_weights((d_model, vocab_size))
# The bias starts at zero. init_weights only zero-fills 1-D shapes, so the
# (1, vocab_size) shape passed to it produced unit-variance random biases.
b_linear = np.zeros((1, vocab_size))
|
|
|
def final_output_layer(x):
    """Project decoder states to probabilities over the vocabulary.

    Applies the module-level ``W_linear``/``b_linear`` projection, then a
    numerically stable softmax over the last axis.
    """
    logits = np.matmul(x, W_linear) + b_linear
    # Shift by the per-row maximum so exp() cannot overflow.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    unnormalized = np.exp(shifted)
    return unnormalized / unnormalized.sum(axis=-1, keepdims=True)
|
|
|
|
|
|
|
# Banner: "--- Transformer simulation with Add & Norm applied: start ---".
# The literal was previously mis-decoded Korean that had been split across
# two lines (a syntax error); the original UTF-8 text is restored.
print("--- Add & Norm 적용된 트랜스포머 시뮬레이션 시작 ---")
|
|
|
|
|
|
|
# Run the encoder on the random input. Its output supplies the keys/values
# for the decoder's cross-attention.
enc_output_final = encoder_block(input_data)
# "Final encoder output shape (K, V source)". Mis-decoded Korean restored and
# the string rejoined onto one line (it had been split, a syntax error).
print(f"인코더 최종 출력 형태 (K, V 소스): {enc_output_final.shape}")
|
|
|
|
|
# Decoder-side inputs: a target sequence shorter than the encoder's, plus the
# causal mask that stops each position from attending to later ones.
dec_seq_len = 5
decoder_input_data = 0.1 * np.random.randn(batch_size, dec_seq_len, d_model)
look_ahead_mask = create_look_ahead_mask(dec_seq_len)
|
|
|
|
|
|
|
|
|
|
|
# Run one decoder layer over the decoder inputs, attending to the encoder output.
dec_output_final = decoder_block(decoder_input_data, enc_output_final, look_ahead_mask)
# "Final decoder output shape". Mis-decoded Korean restored and the string
# rejoined onto one line (it had been split, a syntax error).
print(f"디코더 최종 출력 형태: {dec_output_final.shape}")
|
|
|
|
|
# Project decoder states to per-position vocabulary probabilities.
probabilities = final_output_layer(dec_output_final)
# "Final probability distribution shape (B x S_target x V)". Mis-decoded
# Korean restored and the string rejoined onto one line (it had been split).
print(f"최종 확률 분포 형태 (B x S_target x V): {probabilities.shape}")
|
|
|
# "**Done**". Mis-decoded Korean restored; the stray trailing "|" (a syntax
# error) is removed.
print("\n**완료**")