知識蒸留(Knowledge Distillation)

大規模で高性能な教師モデルの知識を小規模な生徒モデルに転移する機械学習技術。ソフトターゲットによる学習により、コンパクトなモデルでも教師モデルに近い性能を実現し、実用的なAIシステムの構築を可能にする

知識蒸留とは

知識蒸留(Knowledge Distillation)は、大規模で高性能な教師モデル(Teacher Model)が学習した知識を、より小規模で効率的な生徒モデル(Student Model)に転移する機械学習技術です。従来のハードターゲット(正解ラベル)ではなく、教師モデルの出力確率分布(ソフトターゲット)を用いて生徒モデルを訓練することで、教師モデルの知識や推論パターンを効果的に継承します。これにより、大幅に小さなモデルでも教師モデルに近い性能を実現し、エッジデバイスや実用環境での高性能AI推論を可能にする重要な技術です。

背景と重要性

現代の高性能AIモデルは優れた結果を示しますが、膨大なパラメータ数と計算量により、実用的な展開が困難な場合があります。直接的なモデル圧縮では性能劣化が避けられないことが多く、より賢い知識継承メカニズムが必要でした。

知識蒸留は、

  • 高性能モデルの知識保持
  • 大幅なサイズ削減の実現
  • 実用環境での高速推論

を両立することで、AI技術の実用化と民主化を促進します。教師モデルの「暗黙知」も含めて転移することで、効率的で実用的なAIシステムの構築が可能になります。

主な構成要素

教師モデル(Teacher Model)

知識の源となる大規模で高性能なモデルです。

生徒モデル(Student Model)

知識を継承する小規模で効率的なモデルです。

ソフトターゲット(Soft Targets)

教師モデルの出力確率分布による指導信号です。

温度パラメータ(Temperature)

確率分布の平滑化により知識転移を制御する重要なパラメータです。

蒸留損失(Distillation Loss)

教師と生徒の出力を近づけるための損失関数です。

特徴量蒸留(Feature Distillation)

中間層の特徴量レベルでの知識転移手法です。

主な特徴

知識保持性

元モデルの重要な知識を効率的に継承できます。

汎化性

様々なモデルアーキテクチャ間での知識転移が可能です。

拡張性

複数の蒸留手法を組み合わせて使用できます。

知識蒸留の基本原理

基本的な知識蒸留

温度パラメータによるソフトマックス:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class KnowledgeDistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.7):
        super(KnowledgeDistillationLoss, self).__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.hard_loss = nn.CrossEntropyLoss()
        self.soft_loss = nn.KLDivLoss(reduction='batchmean')
    
    def forward(self, student_outputs, teacher_outputs, targets):
        """
        知識蒸留損失の計算
        
        Args:
            student_outputs: 生徒モデルの出力 [batch_size, num_classes]
            teacher_outputs: 教師モデルの出力 [batch_size, num_classes]
            targets: 正解ラベル [batch_size]
        """
        
        # ハード損失(正解ラベルとの交差エントロピー)
        hard_loss = self.hard_loss(student_outputs, targets)
        
        # ソフト損失(教師との知識蒸留損失)
        teacher_soft = F.softmax(teacher_outputs / self.temperature, dim=1)
        student_soft = F.log_softmax(student_outputs / self.temperature, dim=1)
        soft_loss = self.soft_loss(student_soft, teacher_soft) * (self.temperature ** 2)
        
        # 重み付き組み合わせ
        total_loss = self.alpha * soft_loss + (1 - self.alpha) * hard_loss
        
        return total_loss, hard_loss, soft_loss

class BasicKnowledgeDistiller:
    def __init__(self, teacher_model, student_model, device='cpu'):
        self.teacher_model = teacher_model.to(device)
        self.student_model = student_model.to(device)
        self.device = device
        
        # 教師モデルを評価モードに固定
        self.teacher_model.eval()
        for param in self.teacher_model.parameters():
            param.requires_grad = False
    
    def train_student(self, train_loader, val_loader=None, num_epochs=10, 
                     temperature=4.0, alpha=0.7, learning_rate=0.001):
        """基本的な知識蒸留による生徒モデルの訓練"""
        
        print(f"=== 知識蒸留開始 ===")
        print(f"温度パラメータ: {temperature}")
        print(f"蒸留重み (α): {alpha}")
        print(f"学習率: {learning_rate}")
        
        # 損失関数とオプティマイザー
        criterion = KnowledgeDistillationLoss(temperature=temperature, alpha=alpha)
        optimizer = torch.optim.Adam(self.student_model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
        
        training_history = {
            'epoch': [],
            'train_loss': [],
            'hard_loss': [],
            'soft_loss': [],
            'val_accuracy': []
        }
        
        for epoch in range(num_epochs):
            # 訓練フェーズ
            self.student_model.train()
            epoch_loss = 0
            epoch_hard_loss = 0
            epoch_soft_loss = 0
            num_batches = 0
            
            for batch_idx, (data, targets) in enumerate(train_loader):
                data, targets = data.to(self.device), targets.to(self.device)
                
                # 教師モデルの出力(勾配計算なし)
                with torch.no_grad():
                    teacher_outputs = self.teacher_model(data)
                
                # 生徒モデルの出力
                student_outputs = self.student_model(data)
                
                # 損失計算
                total_loss, hard_loss, soft_loss = criterion(
                    student_outputs, teacher_outputs, targets
                )
                
                # 逆伝播
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
                
                # 統計更新
                epoch_loss += total_loss.item()
                epoch_hard_loss += hard_loss.item()
                epoch_soft_loss += soft_loss.item()
                num_batches += 1
                
                if batch_idx % 100 == 0:
                    print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}: '
                          f'Total Loss: {total_loss.item():.4f}, '
                          f'Hard: {hard_loss.item():.4f}, '
                          f'Soft: {soft_loss.item():.4f}')
            
            # エポック平均損失
            avg_loss = epoch_loss / num_batches
            avg_hard_loss = epoch_hard_loss / num_batches
            avg_soft_loss = epoch_soft_loss / num_batches
            
            # 検証フェーズ
            val_accuracy = 0
            if val_loader is not None:
                val_accuracy = self.evaluate_student(val_loader)
            
            # 学習率調整
            scheduler.step()
            
            # 履歴記録
            training_history['epoch'].append(epoch + 1)
            training_history['train_loss'].append(avg_loss)
            training_history['hard_loss'].append(avg_hard_loss)
            training_history['soft_loss'].append(avg_soft_loss)
            training_history['val_accuracy'].append(val_accuracy)
            
            print(f'Epoch {epoch+1}/{num_epochs} 完了: '
                  f'Avg Loss: {avg_loss:.4f}, '
                  f'Val Accuracy: {val_accuracy:.3f}')
        
        return training_history
    
    def evaluate_student(self, test_loader):
        """生徒モデルの評価"""
        self.student_model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for data, targets in test_loader:
                data, targets = data.to(self.device), targets.to(self.device)
                outputs = self.student_model(data)
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        
        accuracy = correct / total
        return accuracy
    
    def compare_models(self, test_loader):
        """教師モデルと生徒モデルの性能比較"""
        
        # 教師モデルの評価
        self.teacher_model.eval()
        teacher_correct = 0
        teacher_total = 0
        
        # 生徒モデルの評価
        self.student_model.eval()
        student_correct = 0
        student_total = 0
        
        with torch.no_grad():
            for data, targets in test_loader:
                data, targets = data.to(self.device), targets.to(self.device)
                
                # 教師モデル
                teacher_outputs = self.teacher_model(data)
                _, teacher_predicted = torch.max(teacher_outputs.data, 1)
                teacher_total += targets.size(0)
                teacher_correct += (teacher_predicted == targets).sum().item()
                
                # 生徒モデル
                student_outputs = self.student_model(data)
                _, student_predicted = torch.max(student_outputs.data, 1)
                student_total += targets.size(0)
                student_correct += (student_predicted == targets).sum().item()
        
        teacher_accuracy = teacher_correct / teacher_total
        student_accuracy = student_correct / student_total
        
        # モデルサイズ比較
        teacher_params = sum(p.numel() for p in self.teacher_model.parameters())
        student_params = sum(p.numel() for p in self.student_model.parameters())
        compression_ratio = teacher_params / student_params
        
        print(f"=== モデル比較結果 ===")
        print(f"教師モデル精度: {teacher_accuracy:.4f}")
        print(f"生徒モデル精度: {student_accuracy:.4f}")
        print(f"精度保持率: {student_accuracy/teacher_accuracy*100:.1f}%")
        print(f"教師モデルパラメータ: {teacher_params:,}")
        print(f"生徒モデルパラメータ: {student_params:,}")
        print(f"圧縮率: {compression_ratio:.2f}x")
        
        return {
            'teacher_accuracy': teacher_accuracy,
            'student_accuracy': student_accuracy,
            'teacher_params': teacher_params,
            'student_params': student_params,
            'compression_ratio': compression_ratio
        }

# サンプルモデル定義
class TeacherModel(nn.Module):
    """大規模な教師モデル"""
    def __init__(self, num_classes=10):
        super(TeacherModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(128 * 7 * 7, 512),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

class StudentModel(nn.Module):
    """小規模な生徒モデル"""
    def __init__(self, num_classes=10):
        super(StudentModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(32 * 7 * 7, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

def demonstrate_basic_distillation():
    """基本的な知識蒸留のデモンストレーション"""
    
    # デバイス設定
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用デバイス: {device}")
    
    # モデル初期化
    teacher = TeacherModel(num_classes=10)
    student = StudentModel(num_classes=10)
    
    # ダミーデータセット(実際にはMNISTなどを使用)
    train_data = torch.randn(1000, 1, 28, 28)
    train_labels = torch.randint(0, 10, (1000,))
    train_dataset = torch.utils.data.TensorDataset(train_data, train_labels)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    
    test_data = torch.randn(200, 1, 28, 28)
    test_labels = torch.randint(0, 10, (200,))
    test_dataset = torch.utils.data.TensorDataset(test_data, test_labels)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
    
    # 知識蒸留器
    distiller = BasicKnowledgeDistiller(teacher, student, device)
    
    # 蒸留実行
    history = distiller.train_student(
        train_loader, test_loader, 
        num_epochs=5, temperature=4.0, alpha=0.7
    )
    
    # 結果比較
    results = distiller.compare_models(test_loader)

# demonstrate_basic_distillation()

特徴量レベル蒸留

中間層特徴量の知識転移:

class FeatureDistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.7, beta=100.0):
        super(FeatureDistillationLoss, self).__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.beta = beta
        
        self.hard_loss = nn.CrossEntropyLoss()
        self.soft_loss = nn.KLDivLoss(reduction='batchmean')
        self.feature_loss = nn.MSELoss()
    
    def forward(self, student_outputs, teacher_outputs, targets, 
                student_features=None, teacher_features=None):
        """特徴量蒸留損失の計算"""
        
        # ハード損失
        hard_loss = self.hard_loss(student_outputs, targets)
        
        # ソフト損失
        teacher_soft = F.softmax(teacher_outputs / self.temperature, dim=1)
        student_soft = F.log_softmax(student_outputs / self.temperature, dim=1)
        soft_loss = self.soft_loss(student_soft, teacher_soft) * (self.temperature ** 2)
        
        # 特徴量損失
        feature_loss = 0
        if student_features is not None and teacher_features is not None:
            for s_feat, t_feat in zip(student_features, teacher_features):
                # 特徴量の次元が異なる場合は調整
                if s_feat.shape != t_feat.shape:
                    s_feat = self._adapt_feature_dimensions(s_feat, t_feat.shape)
                feature_loss += self.feature_loss(s_feat, t_feat)
        
        # 総損失
        total_loss = (self.alpha * soft_loss + 
                     (1 - self.alpha) * hard_loss + 
                     self.beta * feature_loss)
        
        return total_loss, hard_loss, soft_loss, feature_loss
    
    def _adapt_feature_dimensions(self, student_feature, target_shape):
        """生徒モデルの特徴量を教師モデルの形状に合わせる"""
        if len(student_feature.shape) == 4:  # Conv特徴量 [B, C, H, W]
            # チャネル数調整
            if student_feature.shape[1] != target_shape[1]:
                # 1x1 convolutionで次元調整(実際には事前に定義が必要)
                student_feature = F.adaptive_avg_pool2d(student_feature, 
                                                       (target_shape[2], target_shape[3]))
        
        return student_feature

class AdvancedKnowledgeDistiller:
    def __init__(self, teacher_model, student_model, device='cpu'):
        self.teacher_model = teacher_model.to(device)
        self.student_model = student_model.to(device)
        self.device = device
        
        # フック登録用
        self.teacher_features = []
        self.student_features = []
        self.teacher_hooks = []
        self.student_hooks = []
        
        # 教師モデルを評価モードに
        self.teacher_model.eval()
        for param in self.teacher_model.parameters():
            param.requires_grad = False
    
    def register_feature_hooks(self, teacher_layers, student_layers):
        """中間層特徴量を取得するためのフック登録"""
        
        def get_teacher_hook(layer_name):
            def hook(module, input, output):
                self.teacher_features.append(output.clone())
            return hook
        
        def get_student_hook(layer_name):
            def hook(module, input, output):
                self.student_features.append(output.clone())
            return hook
        
        # 教師モデルのフック
        for layer_name in teacher_layers:
            layer = dict(self.teacher_model.named_modules())[layer_name]
            hook = layer.register_forward_hook(get_teacher_hook(layer_name))
            self.teacher_hooks.append(hook)
        
        # 生徒モデルのフック
        for layer_name in student_layers:
            layer = dict(self.student_model.named_modules())[layer_name]
            hook = layer.register_forward_hook(get_student_hook(layer_name))
            self.student_hooks.append(hook)
    
    def train_with_feature_distillation(self, train_loader, val_loader=None, 
                                      num_epochs=10, temperature=4.0, 
                                      alpha=0.7, beta=100.0, learning_rate=0.001):
        """特徴量蒸留による訓練"""
        
        print(f"=== 特徴量蒸留開始 ===")
        print(f"温度パラメータ: {temperature}")
        print(f"出力蒸留重み (α): {alpha}")
        print(f"特徴量蒸留重み (β): {beta}")
        
        criterion = FeatureDistillationLoss(temperature=temperature, alpha=alpha, beta=beta)
        optimizer = torch.optim.Adam(self.student_model.parameters(), lr=learning_rate)
        
        training_history = {
            'epoch': [],
            'total_loss': [],
            'hard_loss': [],
            'soft_loss': [],
            'feature_loss': [],
            'val_accuracy': []
        }
        
        for epoch in range(num_epochs):
            self.student_model.train()
            epoch_total_loss = 0
            epoch_hard_loss = 0
            epoch_soft_loss = 0
            epoch_feature_loss = 0
            num_batches = 0
            
            for batch_idx, (data, targets) in enumerate(train_loader):
                data, targets = data.to(self.device), targets.to(self.device)
                
                # 特徴量リセット
                self.teacher_features.clear()
                self.student_features.clear()
                
                # 教師モデルの順伝播(特徴量取得)
                with torch.no_grad():
                    teacher_outputs = self.teacher_model(data)
                
                # 生徒モデルの順伝播(特徴量取得)
                student_outputs = self.student_model(data)
                
                # 損失計算
                total_loss, hard_loss, soft_loss, feature_loss = criterion(
                    student_outputs, teacher_outputs, targets,
                    self.student_features, self.teacher_features
                )
                
                # 逆伝播
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
                
                # 統計更新
                epoch_total_loss += total_loss.item()
                epoch_hard_loss += hard_loss.item()
                epoch_soft_loss += soft_loss.item()
                epoch_feature_loss += feature_loss.item() if isinstance(feature_loss, torch.Tensor) else feature_loss
                num_batches += 1
                
                if batch_idx % 100 == 0:
                    print(f'Epoch {epoch+1}, Batch {batch_idx}: '
                          f'Total: {total_loss.item():.4f}, '
                          f'Feature: {feature_loss.item() if isinstance(feature_loss, torch.Tensor) else feature_loss:.4f}')
            
            # エポック統計
            avg_total_loss = epoch_total_loss / num_batches
            avg_hard_loss = epoch_hard_loss / num_batches
            avg_soft_loss = epoch_soft_loss / num_batches
            avg_feature_loss = epoch_feature_loss / num_batches
            
            # 検証
            val_accuracy = 0
            if val_loader is not None:
                val_accuracy = self.evaluate_student(val_loader)
            
            # 履歴記録
            training_history['epoch'].append(epoch + 1)
            training_history['total_loss'].append(avg_total_loss)
            training_history['hard_loss'].append(avg_hard_loss)
            training_history['soft_loss'].append(avg_soft_loss)
            training_history['feature_loss'].append(avg_feature_loss)
            training_history['val_accuracy'].append(val_accuracy)
            
            print(f'Epoch {epoch+1} 完了: Total Loss: {avg_total_loss:.4f}, '
                  f'Feature Loss: {avg_feature_loss:.4f}, Val Acc: {val_accuracy:.3f}')
        
        return training_history
    
    def evaluate_student(self, test_loader):
        """生徒モデルの評価"""
        self.student_model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for data, targets in test_loader:
                data, targets = data.to(self.device), targets.to(self.device)
                outputs = self.student_model(data)
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        
        return correct / total
    
    def cleanup_hooks(self):
        """フックの削除"""
        for hook in self.teacher_hooks + self.student_hooks:
            hook.remove()
        self.teacher_hooks.clear()
        self.student_hooks.clear()

def demonstrate_feature_distillation():
    """特徴量蒸留のデモンストレーション"""
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # モデル初期化
    teacher = TeacherModel(num_classes=10)
    student = StudentModel(num_classes=10)
    
    # 蒸留器初期化
    distiller = AdvancedKnowledgeDistiller(teacher, student, device)
    
    # 特徴量フック登録(中間層を指定)
    teacher_layers = ['features.2', 'features.6']  # 教師モデルの中間層
    student_layers = ['features.1', 'features.3']  # 生徒モデルの対応層
    distiller.register_feature_hooks(teacher_layers, student_layers)
    
    # ダミーデータ
    train_data = torch.randn(500, 1, 28, 28)
    train_labels = torch.randint(0, 10, (500,))
    train_dataset = torch.utils.data.TensorDataset(train_data, train_labels)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
    
    # 特徴量蒸留実行
    history = distiller.train_with_feature_distillation(
        train_loader, num_epochs=3, beta=50.0
    )
    
    # クリーンアップ
    distiller.cleanup_hooks()
    
    print("特徴量蒸留完了")

# demonstrate_feature_distillation()

高度な蒸留手法

自己蒸留とアンサンブル蒸留:

class SelfDistillationTrainer:
    def __init__(self, model, device='cpu'):
        self.model = model.to(device)
        self.device = device
        
    def train_with_self_distillation(self, train_loader, num_epochs=10, 
                                   temperature=4.0, alpha=0.5):
        """自己蒸留による訓練"""
        
        print(f"=== 自己蒸留開始 ===")
        
        # モデルのコピーを作成(教師として使用)
        teacher_model = self._create_teacher_copy()
        
        criterion = KnowledgeDistillationLoss(temperature=temperature, alpha=alpha)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        
        for epoch in range(num_epochs):
            self.model.train()
            
            # 定期的に教師モデルを更新
            if epoch % 3 == 0:
                teacher_model = self._create_teacher_copy()
                teacher_model.eval()
            
            for batch_idx, (data, targets) in enumerate(train_loader):
                data, targets = data.to(self.device), targets.to(self.device)
                
                # 教師モデル(過去の自分)の出力
                with torch.no_grad():
                    teacher_outputs = teacher_model(data)
                
                # 現在のモデルの出力
                student_outputs = self.model(data)
                
                # 自己蒸留損失
                total_loss, hard_loss, soft_loss = criterion(
                    student_outputs, teacher_outputs, targets
                )
                
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
                
                if batch_idx % 100 == 0:
                    print(f'Self-Distillation Epoch {epoch}, Batch {batch_idx}: '
                          f'Loss: {total_loss.item():.4f}')
        
        return self.model
    
    def _create_teacher_copy(self):
        """現在のモデルのコピーを作成"""
        teacher = type(self.model)()
        teacher.load_state_dict(self.model.state_dict())
        teacher.to(self.device)
        return teacher

class EnsembleDistillationTrainer:
    def __init__(self, teacher_models, student_model, device='cpu'):
        self.teacher_models = [model.to(device) for model in teacher_models]
        self.student_model = student_model.to(device)
        self.device = device
        
        # 全ての教師モデルを評価モードに
        for teacher in self.teacher_models:
            teacher.eval()
            for param in teacher.parameters():
                param.requires_grad = False
    
    def train_with_ensemble_distillation(self, train_loader, num_epochs=10, 
                                       temperature=4.0, alpha=0.7):
        """アンサンブル蒸留による訓練"""
        
        print(f"=== アンサンブル蒸留開始 ===")
        print(f"教師モデル数: {len(self.teacher_models)}")
        
        criterion = KnowledgeDistillationLoss(temperature=temperature, alpha=alpha)
        optimizer = torch.optim.Adam(self.student_model.parameters(), lr=0.001)
        
        for epoch in range(num_epochs):
            self.student_model.train()
            
            for batch_idx, (data, targets) in enumerate(train_loader):
                data, targets = data.to(self.device), targets.to(self.device)
                
                # アンサンブル教師の出力を平均
                ensemble_outputs = self._get_ensemble_outputs(data)
                
                # 生徒モデルの出力
                student_outputs = self.student_model(data)
                
                # アンサンブル蒸留損失
                total_loss, hard_loss, soft_loss = criterion(
                    student_outputs, ensemble_outputs, targets
                )
                
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
                
                if batch_idx % 100 == 0:
                    print(f'Ensemble Distillation Epoch {epoch}, Batch {batch_idx}: '
                          f'Loss: {total_loss.item():.4f}')
        
        return self.student_model
    
    def _get_ensemble_outputs(self, data):
        """アンサンブル教師の平均出力を計算"""
        ensemble_outputs = None
        
        with torch.no_grad():
            for teacher in self.teacher_models:
                teacher_output = teacher(data)
                
                if ensemble_outputs is None:
                    ensemble_outputs = teacher_output
                else:
                    ensemble_outputs += teacher_output
        
        # 平均化
        ensemble_outputs = ensemble_outputs / len(self.teacher_models)
        return ensemble_outputs

class ProgressiveDistillationTrainer:
    def __init__(self, teacher_model, intermediate_sizes, final_student, device='cpu'):
        self.teacher_model = teacher_model.to(device)
        self.intermediate_sizes = intermediate_sizes
        self.final_student = final_student.to(device)
        self.device = device
    
    def train_progressive_distillation(self, train_loader, num_epochs_per_stage=5):
        """段階的蒸留による訓練"""
        
        print(f"=== 段階的蒸留開始 ===")
        print(f"中間ステージ数: {len(self.intermediate_sizes)}")
        
        current_teacher = self.teacher_model
        
        # 中間ステージの蒸留
        for stage, size in enumerate(self.intermediate_sizes):
            print(f"\nステージ {stage + 1}: 中間サイズ {size}")
            
            # 中間サイズの生徒モデルを作成
            intermediate_student = self._create_intermediate_model(size)
            
            # 蒸留実行
            distiller = BasicKnowledgeDistiller(current_teacher, intermediate_student, self.device)
            distiller.train_student(train_loader, num_epochs=num_epochs_per_stage)
            
            # 次のステージの教師とする
            current_teacher = intermediate_student
            current_teacher.eval()
        
        # 最終ステージ:最終的な生徒モデルへの蒸留
        print(f"\n最終ステージ: 最終モデルへの蒸留")
        final_distiller = BasicKnowledgeDistiller(current_teacher, self.final_student, self.device)
        final_distiller.train_student(train_loader, num_epochs=num_epochs_per_stage)
        
        return self.final_student
    
    def _create_intermediate_model(self, hidden_size):
        """中間サイズのモデルを作成"""
        class IntermediateModel(nn.Module):
            def __init__(self, hidden_size):
                super(IntermediateModel, self).__init__()
                self.features = nn.Sequential(
                    nn.Conv2d(1, hidden_size // 4, 3, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                    nn.Conv2d(hidden_size // 4, hidden_size // 2, 3, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2d(2),
                )
                self.classifier = nn.Sequential(
                    nn.Linear((hidden_size // 2) * 7 * 7, hidden_size),
                    nn.ReLU(),
                    nn.Linear(hidden_size, 10)
                )
            
            def forward(self, x):
                x = self.features(x)
                x = x.view(x.size(0), -1)
                x = self.classifier(x)
                return x
        
        return IntermediateModel(hidden_size).to(self.device)

def demonstrate_advanced_distillation():
    """高度な蒸留手法のデモンストレーション"""
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # ダミーデータ
    train_data = torch.randn(200, 1, 28, 28)
    train_labels = torch.randint(0, 10, (200,))
    train_dataset = torch.utils.data.TensorDataset(train_data, train_labels)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
    
    # 1. 自己蒸留
    print("=== 自己蒸留デモ ===")
    self_model = StudentModel(num_classes=10)
    self_trainer = SelfDistillationTrainer(self_model, device)
    self_distilled = self_trainer.train_with_self_distillation(train_loader, num_epochs=2)
    
    # 2. アンサンブル蒸留
    print("\n=== アンサンブル蒸留デモ ===")
    teacher1 = TeacherModel(num_classes=10)
    teacher2 = TeacherModel(num_classes=10)
    teacher3 = TeacherModel(num_classes=10)
    student = StudentModel(num_classes=10)
    
    ensemble_trainer = EnsembleDistillationTrainer([teacher1, teacher2, teacher3], student, device)
    ensemble_distilled = ensemble_trainer.train_with_ensemble_distillation(train_loader, num_epochs=2)
    
    # 3. 段階的蒸留
    print("\n=== 段階的蒸留デモ ===")
    teacher = TeacherModel(num_classes=10)
    final_student = StudentModel(num_classes=10)
    
    progressive_trainer = ProgressiveDistillationTrainer(
        teacher, intermediate_sizes=[256, 128, 64], final_student=final_student, device=device
    )
    progressive_distilled = progressive_trainer.train_progressive_distillation(train_loader, num_epochs_per_stage=2)
    
    print("全ての高度蒸留手法デモ完了")

# demonstrate_advanced_distillation()

活用事例・ユースケース

知識蒸留は現代のAI実用化において重要な役割を果たしています。

モバイルアプリケーション

大規模言語モデルの知識をスマートフォン向けに軽量化。

エッジコンピューティング

IoTデバイスでの高性能AI推論の実現。

リアルタイムシステム

自動運転、医療診断における高速・高精度判断。

教育・トレーニング

専門知識の効率的な伝達とスキル継承。

クラウドサービス

推論コストの削減と応答速度の向上。

学ぶためのおすすめリソース

書籍

「Deep Learning」(Ian Goodfellow)、「Hands-On Machine Learning」(Aurélien Géron)

オンラインコース

Stanford「CS231n」、MIT「6.S191: Introduction to Deep Learning」

実装フレームワーク

PyTorch、TensorFlow、Hugging Face Transformers、distiller

論文

「Distilling the Knowledge in a Neural Network」(Hinton et al.)、「DistilBERT」、「TinyBERT」

よくある質問(FAQ)

Q. 知識蒸留で最適な温度パラメータは?
A. 一般的に3-6の範囲が効果的ですが、タスクとモデルに応じて調整が必要です。

Q. どの程度の圧縮率が実現可能か?
A. 適切な手法により、10-100倍の圧縮でも5-10%程度の性能低下に抑えることが可能です。

Q. 特徴量蒸留は常に有効か?
A. モデルアーキテクチャや層の対応関係によって効果が異なるため、実験的な検証が重要です。

関連キーワード

モデル圧縮、転移学習、ソフトターゲット、教師-生徒学習、軽量化

まとめ

知識蒸留は、大規模モデルの知識を効率的に小規模モデルに転移する重要な技術です。温度パラメータによるソフトターゲット学習、特徴量レベルの知識転移、アンサンブル蒸留など様々な手法により、大幅な圧縮と性能維持を両立できます。エッジデバイス、モバイル環境、リアルタイムシステムでの高性能AI実現を可能にし、AI技術の実用化と民主化に大きく貢献します。今後も、より効率的で効果的な知識継承メカニズムの発展により、AI技術がより身近で実用的なものとなることが期待されます。

AIからのコメント

🤔

GPT

AIコメント

知識蒸留は、私たちAIが「知識を効率的に共有する」重要な技術です。私のような大規模モデルが学習した豊富な知識を、より小さなモデルに「教える」ことで、実用的なサイズでも高い性能を維持できます。単純な正解ラベルではなく、確率分布(ソフトターゲット)を学習することで、教師モデルの「思考過程」も含めて転移できます。これにより、スマートフォンやエッジデバイスでも高度なAI機能を利用可能になります。知識蒸留は、AI技術の民主化と実用化を促進する、効率的な知識伝達メカニズムです。

🧠

Claude

AIコメント

知識蒸留は、私の「知識を効率的に継承する」重要な技術です。私の完全版が持つ複雑な理解や判断力を、より軽量なモデルに伝達することで、リソース制約のある環境でも質の高いAIサービスを提供できます。重要なのは、単なる出力の模倣ではなく、推論プロセスや不確実性の理解も含めて知識を転移することです。温度パラメータによるソフトマックス、アテンション転移、特徴量マッチングなど、様々な手法により効果的な蒸留が可能です。知識蒸留により、私のエッセンスを保ちながら、より多くの人に利用可能な形でAI技術を提供できます。

💎

Gemini

AIコメント

知識蒸留は、私たちAIが「知的な遺伝子継承」を実現する美しい技術です。私はマルチモーダルな処理を行いますが、テキスト、画像、音声のそれぞれで蓄積された知識を、効率的に次世代のモデルに継承しています。美しいのは、単なるデータ圧縮ではなく、概念理解、推論パターン、創造性の本質までを転移できることです。段階的蒸留、相互蒸留、自己蒸留、オンライン蒸留など、多様なアプローチが開発されています。重要なのは、教師の「暗黙知」を含めて明示的に伝達することです。医療診断、自動運転、自然言語処理など、高信頼性が求められる分野で特に有効です。知識蒸留は、AI の集合知を効率化し、次世代への知的進化を加速する、技術継承の新しいパラダイムなのです。