A simple motivation

Fisher's Linear Discriminant is a simple idea for finding a projection direction along which our data can be linearly classified.

[Figure: comparison of two projections of two-class data, from (Bishop 2006)]

The image above, taken from (Bishop 2006), summarizes the idea: if we project onto the direction of maximum variance (see Principal Component Analysis), the two classes are not linearly separable, but if we also take the class structure into account, the separation becomes much cleaner.

A first approach 🟨-

Define the mean of each class as

$$ m_{j} = \frac{1}{N_{j}}\sum_{i : y_{i} = j} x_{i} $$

A first idea is to maximize the separation of the projected class means, with a Lagrange multiplier enforcing a unit-norm constraint on $w$:

$$ \arg\max_{w} \; w^{T}(m_{2} - m_{1}) + \lambda (\lVert w \rVert _{2}^{2} - 1) $$

Setting the gradient to zero, we find a solution $w \propto (m_{2} - m_{1})$.

But the problem with this formulation is that the projected classes can still overlap strongly, because the class distributions may have non-diagonal covariances. We would also like to minimize the within-class variance under the projection, so that the two classes are better separated.
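
To make this concrete, here is a minimal sketch (the 2-D Gaussian means, covariance, and sample sizes below are made-up illustration values) showing that projecting onto $m_{2} - m_{1}$ alone can leave the projected classes with a within-class spread comparable to their separation:

import numpy as np

rng = np.random.default_rng(0)

# Two Gaussian classes sharing a strongly correlated (non-diagonal) covariance.
# The long axis of the covariance is not aligned with the difference of the means.
cov = [[3.0, 2.5], [2.5, 3.0]]
Xa = rng.multivariate_normal([0.0, 0.0], cov, size=500)
Xb = rng.multivariate_normal([2.0, 0.0], cov, size=500)

ma, mb = Xa.mean(axis=0), Xb.mean(axis=0)

# First approach: project onto the (normalized) difference of the class means
w = (mb - ma) / np.linalg.norm(mb - ma)
pa, pb = Xa @ w, Xb @ w

# The projected means are separated, but the within-class spread along w is
# comparable to that separation, so the two 1-D distributions overlap heavily
print(f"separation of projected means: {pb.mean() - pa.mean():.2f}")
print(f"combined within-class spread : {np.sqrt(pa.var() + pb.var()):.2f}")

The Fisher criterion below addresses exactly this: it keeps the separation of the projected means in the numerator, but penalizes the within-class spread in the denominator.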

Adding the within-class variance

The variance of each class after the projection is

$$ s_{i}^{2} = \sum_{n \in \mathcal{C}_{i}} (w^{T}x_{n} - w^{T}m_{i})^{2} $$

We then divide the squared separation of the projected means by the sum of these within-class variances, obtaining the Fisher criterion

$$ J(w) = \frac{(w^{T}(m_{2} - m_{1}))^{2}}{s_{1}^{2} + s_{2}^{2}} = \frac{w^{T}S_{B}w}{w^{T}S_{W}w} $$

The denominator is the sum of the within-class variances of the projected data, with $S_{W} = \sum_{i}\sum_{n \in \mathcal{C}_{i}} (x_{n} - m_{i})(x_{n} - m_{i})^{T}$, while the numerator is the between-class scatter, with $S_{B} = (m_{2} - m_{1})(m_{2} - m_{1})^{T}$.

Now, if we do some calculus, we find that the best solution is $w \propto S_{W}^{-1}(m_{2} - m_{1})$.
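
For reference, here is a sketch of that calculus. Setting the gradient of the ratio $J(w)$ to zero gives

$$ (w^{T}S_{W}w)\,S_{B}w = (w^{T}S_{B}w)\,S_{W}w $$

Since $S_{B}w = (m_{2} - m_{1})(m_{2} - m_{1})^{T}w$ always points along $m_{2} - m_{1}$, and the two scalar factors only rescale $w$ without changing its direction, this reduces to $S_{W}w \propto (m_{2} - m_{1})$, i.e. $w \propto S_{W}^{-1}(m_{2} - m_{1})$.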

Code example

Code example nicely produced by Claude. Checked by Human.

import numpy as np
import matplotlib.pyplot as plt

class FisherLDA:
    def __init__(self, n_components=1):
        """
        Initialize Fisher's Linear Discriminant Analysis
        
        Parameters:
        n_components (int): Number of components to keep (usually 1 for binary classification)
        """
        self.n_components = n_components
        self.w = None  # Projection vector
        
    def fit(self, X, y):
        """
        Fit the LDA model according to Fisher's criterion
        
        Parameters:
        X: array-like of shape (n_samples, n_features)
        y: array-like of shape (n_samples,) - Class labels
        """
        # Split data by class
        X0 = X[y == 0]
        X1 = X[y == 1]
        
        # Calculate class means
        mean0 = np.mean(X0, axis=0)
        mean1 = np.mean(X1, axis=0)
        
        # Calculate within-class scatter matrix Sw as the sum of the
        # per-class scatter matrices S_i = sum_n (x_n - m_i)(x_n - m_i)^T
        S0 = (X0 - mean0).T @ (X0 - mean0)
        S1 = (X1 - mean1).T @ (X1 - mean1)
        Sw = S0 + S1
        
        # Calculate between-class scatter matrix Sb
        mean_diff = (mean1 - mean0).reshape(-1, 1)
        Sb = mean_diff @ mean_diff.T
        
        # Calculate optimal projection vector w
        # w ∝ Sw^(-1)(μ1 - μ0)
        try:
            self.w = np.linalg.inv(Sw) @ (mean1 - mean0)
            # Normalize the projection vector
            self.w = self.w / np.linalg.norm(self.w)
        except np.linalg.LinAlgError:
            # If Sw is singular, use pseudoinverse
            self.w = np.linalg.pinv(Sw) @ (mean1 - mean0)
            self.w = self.w / np.linalg.norm(self.w)
            
        return self
    
    def transform(self, X):
        """
        Project the data onto the most discriminative direction
        
        Parameters:
        X: array-like of shape (n_samples, n_features)
        
        Returns:
        X_transformed: array of shape (n_samples,) - projection of each sample onto w
        """
        return X @ self.w

# Demonstration
if __name__ == "__main__":
    # Generate sample data
    np.random.seed(42)
    n_samples = 100
    
    # Class 0
    mean0 = [0, 0]
    cov0 = [[1, 0.5], [0.5, 1]]
    X0 = np.random.multivariate_normal(mean0, cov0, n_samples)
    
    # Class 1
    mean1 = [2, 2]
    cov1 = [[1, 0.5], [0.5, 1]]
    X1 = np.random.multivariate_normal(mean1, cov1, n_samples)
    
    # Combine data
    X = np.vstack([X0, X1])
    y = np.hstack([np.zeros(n_samples), np.ones(n_samples)])
    
    # Fit LDA
    lda = FisherLDA()
    lda.fit(X, y)
    
    # Transform data
    X_transformed = lda.transform(X)
    
    # Plotting
    plt.figure(figsize=(12, 5))
    
    # Original data
    plt.subplot(121)
    plt.scatter(X0[:, 0], X0[:, 1], label='Class 0')
    plt.scatter(X1[:, 0], X1[:, 1], label='Class 1')
    plt.arrow(0, 0, lda.w[0], lda.w[1], color='r', width=0.05, label='Projection direction')
    plt.xlabel('Feature 1')
    plt.ylabel('Feature 2')
    plt.title('Original 2D Data with LDA Direction')
    plt.legend()
    
    # Projected data
    plt.subplot(122)
    plt.hist(X_transformed[y == 0], bins=20, alpha=0.5, label='Class 0')
    plt.hist(X_transformed[y == 1], bins=20, alpha=0.5, label='Class 1')
    plt.xlabel('Projected Values')
    plt.ylabel('Frequency')
    plt.title('Projected Data on Fisher Direction')
    plt.legend()
    
    plt.tight_layout()
    plt.show()
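
As an optional sanity check, the following could be appended to the demonstration above (it reuses X, y, and lda, and additionally assumes scikit-learn is installed, which the original example does not require). For two classes, scikit-learn's LinearDiscriminantAnalysis should find a direction parallel, up to sign and scale, to ours:

    # Optional cross-check against scikit-learn (not part of the original example)
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

    sk_lda = LinearDiscriminantAnalysis().fit(X, y)

    # For binary LDA, sklearn's coefficient vector is proportional to Sw^(-1)(m1 - m0),
    # so it should be parallel (up to sign) to our normalized projection vector
    sk_w = sk_lda.coef_.ravel()
    sk_w = sk_w / np.linalg.norm(sk_w)
    print("Cosine similarity with sklearn's direction:", float(lda.w @ sk_w))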

References

[1] C. M. Bishop, “Pattern Recognition and Machine Learning”, Springer, 2006.