from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import matplotlib.pyplot as plt
# Load the Mistral base model and its Instruct counterpart by checkpoint name.
base_model_name, chat_model_name = (
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1",
)
# Each call builds one full causal-LM model from the named checkpoint.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
)
chat_model = AutoModelForCausalLM.from_pretrained(
    chat_model_name,
)
def calculate_weight_diff(base_model, chat_model):
    """Return the mean absolute difference of each matching parameter pair.

    The parameters of the two modules are paired in ``named_parameters()``
    order. Parameters whose base-side ``requires_grad`` is False are skipped.

    Args:
        base_model: Module exposing ``named_parameters()`` (reference side).
        chat_model: Module with the same parameter structure as ``base_model``.

    Returns:
        A list of floats, one per compared parameter, each being
        ``mean(|base - chat|)`` over all elements of that parameter.

    Raises:
        ValueError: If the two modules do not expose the same number of
            parameters, or if a pair of parameters has different names.
    """
    base_params = list(base_model.named_parameters())
    chat_params = list(chat_model.named_parameters())

    # A plain zip() would silently stop at the shorter module and hide a
    # structural mismatch, so the counts are checked up front.
    if len(base_params) != len(chat_params):
        raise ValueError(
            "parameter count mismatch: "
            f"{len(base_params)} vs {len(chat_params)}"
        )

    diffs = []
    for (base_name, base_param), (chat_name, chat_param) in zip(base_params, chat_params):
        # Comparing differently named tensors would yield a meaningless number.
        if base_name != chat_name:
            raise ValueError(
                f"parameter name mismatch: {base_name!r} vs {chat_name!r}"
            )
        if not base_param.requires_grad:
            continue
        # detach() yields gradient-free views without going through `.data`.
        diff = torch.abs(base_param.detach() - chat_param.detach())
        diffs.append(diff.mean().item())
    return diffs
def calculate_layer_diffs(base_model, chat_model):
    """Return one averaged weight difference per layer.

    Layers are taken pairwise from ``base_model.model.layers`` and
    ``chat_model.model.layers``. For each pair, the module-level
    ``calculate_weight_diff`` is called with the two layers and the list it
    returns is reduced to its arithmetic mean (unweighted by parameter size).

    Returns:
        A list of floats, one entry per layer pair.
    """
    def _average(values):
        return sum(values) / len(values)

    paired_layers = zip(base_model.model.layers, chat_model.model.layers)
    return [
        _average(calculate_weight_diff(base_block, chat_block))
        for base_block, chat_block in paired_layers
    ]
def visualize_diffs(diffs):
    """Show a log-scale bar chart with one bar per entry of ``diffs``.

    Args:
        diffs: Sequence of numeric per-layer differences; the position in the
            sequence is used as the bar's x coordinate and tick label.
    """
    positions = range(len(diffs))

    fig = plt.figure(figsize=(12, 6))
    ax = fig.gca()
    ax.bar(positions, diffs)

    ax.set_xlabel("Layer")
    ax.set_ylabel("Average Weight Difference")
    ax.set_title("Weight Difference between Base and Instruct Models")

    # One tick per bar, labelled with the bar's index.
    ax.set_xticks(positions)
    ax.set_xticklabels(positions)
    ax.set_yscale('log')

    plt.show()
# Compute the per-layer weight differences between the two loaded models.
layer_diffs = calculate_layer_diffs(base_model, chat_model)
# Visualize the per-layer weight differences.
visualize_diffs(layer_diffs)
# Per-component visualization for each layer
def calculate_weight_diff(base_weight, chat_weight):
    """Return the mean absolute element-wise difference of two tensors.

    Args:
        base_weight: Tensor from the base model.
        chat_weight: Tensor from the chat model, subtracted from ``base_weight``.

    Returns:
        The mean of ``|base_weight - chat_weight|`` as a Python number.
    """
    delta = base_weight - chat_weight
    mean_abs = delta.abs().mean()
    return mean_abs.item()
def calculate_layer_diffs(base_model, chat_model):
    """Return, for every layer, a dict of per-component weight differences.

    Layers are taken pairwise from ``base_model.model.layers`` and
    ``chat_model.model.layers``. For each pair, the ``.weight`` tensors of
    nine sub-modules are compared with the module-level
    ``calculate_weight_diff``.

    Returns:
        A list with one dict per layer pair. Each dict maps a component key
        (e.g. ``'mlp_down_proj'``, ``'self_attn_q_proj'``) to the value
        returned by ``calculate_weight_diff`` for that component.
    """
    # (result key, attribute path from a layer to the module owning `.weight`)
    component_paths = (
        ('input_layernorm', ('input_layernorm',)),
        ('mlp_down_proj', ('mlp', 'down_proj')),
        ('mlp_gate_proj', ('mlp', 'gate_proj')),
        ('mlp_up_proj', ('mlp', 'up_proj')),
        ('post_attention_layernorm', ('post_attention_layernorm',)),
        ('self_attn_q_proj', ('self_attn', 'q_proj')),
        ('self_attn_k_proj', ('self_attn', 'k_proj')),
        ('self_attn_v_proj', ('self_attn', 'v_proj')),
        ('self_attn_o_proj', ('self_attn', 'o_proj')),
    )

    def _weight_at(layer, path):
        # Walk the attribute path, then return the final module's weight.
        module = layer
        for attr in path:
            module = getattr(module, attr)
        return module.weight

    per_layer = []
    for base_block, chat_block in zip(base_model.model.layers, chat_model.model.layers):
        per_layer.append({
            key: calculate_weight_diff(
                _weight_at(base_block, path),
                _weight_at(chat_block, path),
            )
            for key, path in component_paths
        })
    return per_layer
def visualize_layer_diffs(layer_diffs):
    """Show one single-column heatmap per weight component across all layers.

    The figure title is built from the module-level ``base_model_name`` and
    ``chat_model_name``.

    Args:
        layer_diffs: Non-empty list with one dict per layer. Each dict maps a
            component name to a numeric difference; the keys of the first
            dict define which components are plotted.

    Raises:
        IndexError: If ``layer_diffs`` is empty.
    """
    # Imported locally: the file's import block does not bring seaborn into
    # scope, so the bare `sns` reference would otherwise raise NameError.
    import seaborn as sns

    num_layers = len(layer_diffs)
    components = list(layer_diffs[0])

    # squeeze=False always yields a 2-D array of Axes, so indexing also works
    # when there is only a single component.
    fig, axs = plt.subplots(1, len(components), figsize=(24, 8), squeeze=False)
    fig.suptitle(f"{base_model_name} <> {chat_model_name}", fontsize=16)

    for ax, component in zip(axs[0], components):
        # One row per layer, one column: an (num_layers x 1) grid.
        column = [[layer_diff[component]] for layer_diff in layer_diffs]
        sns.heatmap(column, annot=True, fmt=".6f", cmap="YlGnBu", ax=ax, cbar_kws={"shrink": 0.8})
        ax.set_title(component)
        # Rows enumerate layers, so the y axis is the layer axis; the single
        # column carries the difference values.
        ax.set_xlabel("Difference")
        ax.set_ylabel("Layer")
        ax.set_xticks([])
        # Heatmap cell i spans [i, i + 1]; put each tick at the cell centre so
        # the label lines up with its row instead of the cell edge.
        ax.set_yticks([row + 0.5 for row in range(num_layers)])
        ax.set_yticklabels(range(num_layers))
        # Flip the orientation set by the heatmap so layer 0 is at the bottom.
        ax.invert_yaxis()

    plt.tight_layout()
    plt.show()
# Compute the per-layer, per-component weight differences.
layer_diffs = calculate_layer_diffs(base_model, chat_model)
# Visualize the weight differences as heatmaps.
visualize_layer_diffs(layer_diffs)