
Multi-Modal Alzheimer's Detection from Speech

About

This project is a multi-modal machine learning pipeline that detects Alzheimer's Disease (AD) from spontaneous speech. By combining deep acoustic representations, contextual linguistic embeddings, and handcrafted prosodic and paralinguistic features, the system analyzes both what a patient says and how they say it. A co-attention mechanism aligns verbal content with vocal cues, making the model effective at identifying the subtle cognitive and linguistic markers of dementia in clinical audio samples.

Tech Stack

PyTorch
Hugging Face Transformers
Wav2Vec2
BERT
Librosa
spaCy
Python

Features

Multi-Modal Integration

Combines raw speech audio, text transcripts, prosodic metrics (pitch, jitter), and paralinguistic traits (speech rate, pauses).

Cross-Modal Co-Attention Fusion

Uses 3 stacked Co-Attention blocks to dynamically align acoustic cues with text.

Optimized for Small Medical Datasets

Utilizes frozen, large pre-trained encoders (~144M parameters) paired with a lightweight, trainable fusion layer (~6M parameters) to prevent overfitting.
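The freeze-and-fuse pattern can be sketched as below. The tiny stand-in encoder and head are illustrative only; in the real pipeline the frozen modules would be the pre-trained wav2vec2 and BERT encoders, and the trainable part would be the fusion and classification layers.

```python
import torch.nn as nn

# Stand-in for a large pre-trained encoder (wav2vec2/BERT in the real pipeline)
encoder = nn.Sequential(nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 768))
# Lightweight trainable head on top of the frozen features
fusion = nn.Linear(768, 2)

# Freezing: parameters receive no gradients and are skipped by the optimizer
for p in encoder.parameters():
    p.requires_grad = False

def count_trainable(m: nn.Module) -> int:
    """Number of parameters that will actually be updated during training."""
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

print(count_trainable(encoder))  # 0 after freezing
print(count_trainable(fusion))   # 768 * 2 weights + 2 biases = 1538
```

Only the small trainable parameter budget has to be fit to the limited medical data, which is what keeps the model from overfitting.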

Engineered Feature Extraction

Automatically extracts and normalizes handcrafted voice rhythm and linguistic fluency metrics.

Robust Training Pipeline

Includes 10-fold cross-validation, mixed precision (FP16) training, and advanced learning rate scheduling.
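The skeleton of such a loop might look like the following. The toy dataset, the cosine annealing schedule, and the tiny linear model are placeholders; the 10-fold split and the `GradScaler`/`autocast` FP16 machinery are the parts the source describes.

```python
import numpy as np
import torch
from sklearn.model_selection import StratifiedKFold

# Hypothetical dataset: 50 samples, balanced binary labels
X = np.random.randn(50, 10).astype(np.float32)
y = np.array([0, 1] * 25)

use_cuda = torch.cuda.is_available()
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
fold_sizes = []

for train_idx, val_idx in skf.split(X, y):
    model = torch.nn.Linear(10, 2)  # stand-in for the fusion model
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # Cosine annealing is one plausible "advanced" schedule, assumed here
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=5)
    # GradScaler/autocast enable FP16 only when a GPU is present
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

    for _ in range(2):  # a couple of epochs per fold, for illustration
        opt.zero_grad()
        with torch.autocast("cuda" if use_cuda else "cpu", enabled=use_cuda):
            logits = model(torch.from_numpy(X[train_idx]))
            loss = torch.nn.functional.cross_entropy(
                logits, torch.from_numpy(y[train_idx]))
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        sched.step()

    fold_sizes.append(len(val_idx))
```

Each of the 10 folds holds out a disjoint validation split, so every sample is validated exactly once.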

Architecture

01

Input Layer

Ingests raw audio waveforms (16 kHz), CHAT-formatted transcripts, 6 prosodic features, and 4 paralinguistic features.

02

Feature Encoding

Utilizes a frozen wav2vec2-base-960h model for acoustic embeddings and bert-base-uncased for contextual word embeddings.
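The encoding stage can be sketched as below. To keep the example self-contained it builds randomly-initialized models from the default configs; the real pipeline would instead call `.from_pretrained("facebook/wav2vec2-base-960h")` and `.from_pretrained("bert-base-uncased")`. Both architectures emit 768-dimensional hidden states.

```python
import torch
from transformers import BertConfig, BertModel, Wav2Vec2Config, Wav2Vec2Model

# Random-init stand-ins with the same architectures as the pre-trained checkpoints
audio_enc = Wav2Vec2Model(Wav2Vec2Config())
text_enc = BertModel(BertConfig())

# Frozen encoders: inference mode only, no gradient flow
for enc in (audio_enc, text_enc):
    enc.eval()
    for p in enc.parameters():
        p.requires_grad = False

with torch.no_grad():
    wav = torch.randn(1, 16000)             # 1 s of 16 kHz audio
    ids = torch.randint(0, 30000, (1, 32))  # dummy token ids
    a = audio_enc(wav).last_hidden_state    # (1, audio frames, 768)
    t = text_enc(ids).last_hidden_state     # (1, 32, 768)
```

Note that wav2vec2 downsamples the waveform (roughly one frame per 20 ms), so the acoustic and linguistic sequences have different lengths; the fusion layer has to handle that mismatch.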

03

Co-Attention Fusion Layer

Instead of simple concatenation, the model uses stacked Multi-Head Cross-Modal Attention, allowing the audio stream to attend to the text stream and vice versa.
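A minimal version of one such bidirectional block, built on `nn.MultiheadAttention`, is sketched below; the residual-plus-LayerNorm wiring and the head count are assumptions about the design, not a verbatim reproduction of it.

```python
import torch
import torch.nn as nn

class CoAttentionBlock(nn.Module):
    """One bidirectional cross-modal attention block (illustrative)."""
    def __init__(self, dim: int = 768, heads: int = 8):
        super().__init__()
        self.a2t = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.t2a = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm_a = nn.LayerNorm(dim)
        self.norm_t = nn.LayerNorm(dim)

    def forward(self, audio, text):
        # Audio queries attend over text tokens, and vice versa
        a_attn, _ = self.a2t(audio, text, text)
        t_attn, _ = self.t2a(text, audio, audio)
        return self.norm_a(audio + a_attn), self.norm_t(text + t_attn)

blocks = nn.ModuleList(CoAttentionBlock() for _ in range(3))  # 3 stacked blocks

audio = torch.randn(2, 49, 768)  # (batch, audio frames, dim)
text = torch.randn(2, 32, 768)   # (batch, text tokens, dim)
for blk in blocks:
    audio, text = blk(audio, text)
```

Because each stream only queries the other, the two sequences keep their own lengths throughout the stack; alignment happens in the attention weights rather than by padding one modality to match the other.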

04

Feature Aggregation & Projection

Summarizes outputs via acoustic pooling and linguistic pooling. These 768-dim vectors are concatenated with prosodic/paralinguistic features (totaling 1546 dimensions) and compressed via a dense projection layer.
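The arithmetic works out to 768 + 768 + 6 + 4 = 1546, as the following sketch shows. Mean pooling and the 256-dimensional projection are assumptions; the source specifies only "pooling" and "a dense projection layer".

```python
import torch
import torch.nn as nn

audio = torch.randn(2, 49, 768)    # co-attended acoustic sequence
text = torch.randn(2, 32, 768)     # co-attended linguistic sequence
prosodic = torch.randn(2, 6)       # handcrafted prosodic features
paraling = torch.randn(2, 4)       # handcrafted paralinguistic features

a_pool = audio.mean(dim=1)         # acoustic pooling  -> (2, 768)
t_pool = text.mean(dim=1)          # linguistic pooling -> (2, 768)

# 768 + 768 + 6 + 4 = 1546 total dimensions
fused = torch.cat([a_pool, t_pool, prosodic, paraling], dim=-1)

# Dense projection compresses the fused vector (256 is an assumed width)
proj = nn.Sequential(nn.Linear(1546, 256), nn.ReLU(), nn.Dropout(0.3))
z = proj(fused)
```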

05

Classification Head

A final linear layer produces two logits, one per class; a softmax converts them into probabilities for Control (0) vs. Dementia (1).
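The head reduces to a few lines; the 256-dimensional input matches the assumed projection width from the previous stage.

```python
import torch
import torch.nn as nn

head = nn.Linear(256, 2)  # two raw class scores (logits)

z = torch.randn(2, 256)                # projected fusion vectors
logits = head(z)
probs = torch.softmax(logits, dim=-1)  # [P(Control), P(Dementia)] per sample
pred = probs.argmax(dim=-1)            # 0 = Control, 1 = Dementia
```

During training the softmax is typically folded into `nn.CrossEntropyLoss`, which consumes the raw logits directly for better numerical stability.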

Future Improvements

Real-Time Clinical Inference

Develop a lightweight streaming API to process patient audio and provide live diagnostic probabilities during clinical assessments.

Multi-Class Cognitive Profiling

Expand the binary classifier to detect and categorize varying stages of cognitive decline.

Explainability Dashboard

Build a visualization tool for clinicians that highlights the exact transcript words and audio segments that contributed most to the prediction.