제목: Robust Sleep Staging over Incomplete Multimodal Physiological Signals via Contrastive Imagination
저자: Qi Shen et al.
인용: 0 (until, 2024.04.19)
그룹: College of Medicine and Biological Information Engineering, Northeastern University, China
학회: NIPS2024
링크: https://neurips.cc/virtual/2024/poster/94475
우선 이 논문을 리뷰하는 이유는, 2024 NIPS에 실린 몇 안되는 수면단계분류 관련 논문이기 때문이다. 비록 이 논문의 리뷰가 끔찍할 지라도, 결과론적으로는 Poster section에 Accepted 되었다는 것과, SMCCL 방법론이 재미있었다고 느꼈기 때문이다.
개인적인 생각으로, 저자들이 리뷰를 통과한 이유는 별 알맹이 없는 수식들을 열거하고, 이 과정에서 내 기준에는 필요 없어보이는 Loss function 및 모듈들이 다수 등장하여 리뷰어들을 혼란스럽게 만들었기 때문이라고 생각한다.
ChatGPT가 많이 똑똑해져서, 이런 글을 쓸 필요도 없어졌지만, 모처럼 많은 Loss function + 복잡한 수식 + NIPS + 수면단계분류가 겹쳤기에, 혼자서 정리하고 넘어가단다면, 비슷한 논문을 보는데에 있어서 더욱 편하게 넘어갈 것이라 생각해 혼자 공부용으로 글을 남기게 되었다.
Abstract
Motivation:
- Traditional multi-modal automated sleep staging (ASS) methods assume the availability of complete, unimpaired physiological signals, which is unrealistic in real-world scenarios
- Missing modalities and impaired signals are common, requiring more robust methods to extract temporal context information effectively
Contributions:
To address these challenges, the paper proposes a unified framework comprising three key modules:
- MAIM (Modal Awareness Imagination Module)
- Handles missing modality problems by reconstructing missing signals using latent representations. - SMCCL (Semantic & Modal Calibration Contrastive Learning)
- Ensures semantic consistency across modalities while preserving their unique characteristics through multi-level contrastive learning - MCTA (Multi-level Cross-Branch Temporal Attention)
- Captures cross-scale temporal dependencies in physiological signals, improving temporal context representation
임상적으로, 신호에 노이즈가 발생하는 경우는 종종 발생한다. 따라서 본 논문에서 말하고자 하는 문제 해결 과정이 필수적이라 생각하지만, 결국 우리가 입력 데이터로 사용하는 신호들은 이러한 노이즈를 포함하여 측정 및 기록된다. 따라서 이러한 일이 실제 임상에서도 빈번하게 발생한다고 하더라도, 수면 데이터에는 언제 노이즈가 발생했는지 적지 않는 경우들도 많으며, 노이즈인지, 아니면 환자의 신호가 정말 문제가 있는건지 구분이 어려운 상황에서, 이런 신호들을 굳이굳이 원본 신호를 손상/왜곡시켜가며 복구한다음 성능을 올리는 것이 과연 정말로 필요한 일인지는 모르겠다. 정확도가 정말 많이 올라가는 것이 아니라면, 난 불필요한 작업이라 생각이 든다 (그리고 결과론적이지만 실제 성능도 별로였기에,, 이게 왜 Accepted? 라는 생각이 든다)
저자들에게 물어보고 싶은점이 있다면 Noise reconstruction algorithm에서 과연 이 데이터가 노이즈를 얼마나 잘 찾아내고, 복구했는지, 그 정도를 수치로 보고싶다. XML/Json 파일에서의 Artifact 라벨을 정답으로 해서의 복구 정확도? 이런 부분이 수치적으로 궁금하다. 물론 이건 ALL-channel temporal information의 문제이고, 본 논문에서 말하는 incomplete signal problem의 경우는 수치적으로 어떻게 측정해야할지 모르겠다. 노이즈된 신호인지 vs 실제 왜곡된 신호인지에 대한 구분은 너무 어려우니까,,
Introduction
1. ASS has seen significant advancements through multi-modal AI models that effectively fuse various PS
>> However, these models still face critical challenges in addressing sensor malfunctions, which are common in real-world scenarios, leading to impaired or missing data
2. Most ASS models rely on RNN to capture temporal dependencies through learnable hidden states, which are well-suited for sequential modeling
++ Recently, Transformer-based models have shown strong performance in sequence tasks, but their ability to model recurrent dependencies remains weaker compared to RNNs
3. Existing methods often focus on mining temporal correlations at a single level, either at the intra/inter-epoch level
>> This narrow focus limits their ability to fully capture the complex variability patterns present in time series PS
>> As a result, current temporal models struggle to understand these patterns comprehensively, thereby impacting the overall performance of sleep staging
저자들은 RNN에 비해 Transformer의 recurrent dependency가 떨어진다고 주장하지만,, 난 딱히 동의하지는 않는다. 추측해보자면, 저자들이 제안하는 모듈에서 RNN이 Transformer 계열보다 성능이 높게 나왔지 않았을까라고 생각한다. 아마도 MCTA 부분에서 Transformer를 사용했을 때 성능의 이점이 없었을 것이라 생각한다. 개인적인 경험으로도 MCTA와 유사한 알고리즘을 Transformer에 적용해봤을 때 성능이 별로였던 기억이 존재한다.
Related Work 1
CoRe-SleepNet [링크]
>> 두 개의 Transformer 모델을 Raw signal과 Spectrogram을 입력으로 넣어 Cross-Attention기반 Fusion 방식을 보여주었다
Related Work 2
두 번째 논문은 XsleepNet이라고 하는, 후이 판 교수님 논문이다. Gradient blending (CVPR2021)을 활용하여 Raw signal과 Spectrogram을 최적화 시켰다.
이 연구의 Contribution은 Multi-modal 에서의 성능 향상은 크지 않았기에, Multi-View에서의 성능 상승이 기여라고 생각한다.
Methodology - Definition
>> 기존 연구 (CoRe-SleepNet)과 다른 점은 Chunk-based missing pattern을 사용하여 ALL-channel (All modality)가 특정 epoch, 시간대에 노이즈 혹은 손상이 되어도 효과적으로 복구를 할 수 있게 설계되었다는 점이다.
>> 이 과정에서 저자들이 가장 먼저 설명하는 특징으로는 Mask matrix Z의 사용이다.
위와 같은 Mask Matrix Z가 존재할 때 위첨자 j는 modalitiy를 의미하고, 밑의 i는 sequenial information이라고 보면 된다
~X는 incomplete multimodal signals이다. 이 것을 T라는 단위로 신혼의 길이인 N을 나눠 chunk를 만들고 ->N 으로 표기한다
S라는 modality matrix를 정의는 위와 같다.
Methodology - Model Architecture
우선 Main Figure는 다음과 같지만, 이 그림만 보면 확 와닿지는 않는다. 따라서 이를 좀 수정하자면
이 그림을 좀 더 아래에서 풀어 나가기 쉽게 내 방식으로 그림을 재분배하자면,
Step 1 -> Recovering process and Modality Alignmnet
Step 2 -> MCTA process 라고 보면 된다.
Methodology - MAIM
Recovering process의 목표
>> Incomplete view에서 Encoder - Decoder를 통해서 상호보완을 통해 Complete View로 만든다.
Encoder에서는 앞서 정의한 Z를 활용하여 Encoder Ej를 통해 Incomplete view X를 압축하고 이 압축된 latent representation을 바탕으로 incomplete view 인지 아니면 Normal signal인지 확인해보는 과정을 거친다.
결과적으로 fi 는 i-th epoch에서 가용한 모든 모달리티의 정보를 종합한 다중 모달리티의 공유 표현이 된다. 비록 incomplete view X를 입력으로 넣었을지라도, 최종 과정에서 Multimodal Variational AutoEncoder (MVAE)를 통해 modality fusion이 발생하므로, fi는 complete view X에 의해 압축됐다고 저자들은 주장한다.
Decoder에서는 앞서서 Incomplete view라고 여겨졌던 X를 복구하기 위한 process라고 보면 된다. MAE와 KL divergence를 활용한다고 보면 되며, N(0, 1)에 가깝게 복구한다고 보면 된다.
Methodology - SMCCL
앞서 말한 것 처럼, 복구와 동시에 Multi-modal 정보를 잘 학습하기 위한 시도를한다
기존 Multi-modal과는 다른점이 있다면 3개의 순위로 나누어 Contrastive Learning을 수행한다. 기존 연구에 있어서는 positive/negative sample로 나누어 Contrastive Learning을 수행하지만, 여기는 그러한 쌍이 3개가 된다고 보면 된다
1st: Same Class in Same Modality
2nd: Same Class in Different Modality
3rd: Different Class in Different Modality
그리고 이것을 충족시키기 위한 수식은 4번 수식이다
위의 4번 수식이 라벨에 대한 수식이라고 한다면, 이를 업데이트 하기 위한 Contrastive Learning 수식은 아래와 같다.
1은 indicator로써 modality가 같은지 다른지 확인용도로 쓰이고, 1부터 M까지 돌기 떄문에, θk는 0~1 사이가 된다.
I(,)는 Mutual information, H(,) Joint Entropy 이며, 이를 증명하는건 Supplementary 참고하면 된다. 그 결과는 6번이다.
θk가 높게 나온다면, mutual information이 높다는 것이지만, 반대로 joint entropy가 낮다는 말이 되기도 한다.
따라서 이를 정리한 최종 수식은 아래와 같다.
Methodology - MCTA
CIMSleepNet 의 구조는 다음과 같다. Intra-epoch 사이의 MCTA 알고리즘이 들어가는 것을 알 수 있다.
>> 위의 그림을 수식으로 접목시키면 아래와 같다. 별 내용이 없어서 수식만 보고 이해해도 충분할 것 같다.
Performance
>> 당연히 Incomplete 한 상황에서는 복구 작업이 존재하는 CIMSleepNet이 더 좋을 것이고, Complete 한 상황에서 모델 성능이 좋아져야, "우리가 모르는 노이즈 신호 데이터가 많구나" 이렇게 생각할텐데,,, 성능도 더 낮다. 모델의 학습이 잘 이루어졌다 가정 할 때, 이 결과는 수면단계분류에서 이런 노이즈가 별로 영향을 안끼치지 않았을까라고 이해해도 되지않나 생각한다.
아래 그림은 CoRe-SleepNet과의 Incomplete한 상황에서의 성능비교인데, 굳이 왜 이런 시도를 했는지 모르겠다. CoRe-SleepNet의 Contribution은 EEG, EOG를 융합하는데에 있지, Missing modality를 복구하는데에 있지는 않기 때문이다. 차라리 다른 Missing modality 복구작업과 비교하는 테이블이 있었다면 좋았을 것 같다
마지막으로 Visualization이다. 이 그림을 보고 느낀점은 SMCCL의 복구 과정이 확실히 참신하다고 느꼈는데, Embedding space또한 잘 구분 되었다고 생각한다. 차라리 별 이상한거 빼고 이 SMCCL만 강조했으면 더 좋은 논문이 되지 않았을까 싶다.
물론 Table 3를 보면 SMCCL의 장점이 성능에서도 두드러지게 나타난다고 생각한다. 그치만,, 내가 생각하는 이 논문의 성능 패착 원인은 백본 Network의 성능이 애초에 낮은 걸 써서 그렇다고 생각한다.
아니,, 좋은 backbone많은데,, MAIM 이런거 개발하자고,, 억지로 제안하는 알고리즘에 성능 올라가는 백본을 끼워 맞춰서 성능이 별로였다고 생각한다. 기존에 좋은 모델들은 이런거 안써도 잘 Modality 간 학습을 통해 성능을 뽑아내기에, 성능이 높지 않나 생각한다.
내 개인적인 생각은 4번 리뷰어의 첫번째 지적과 가장 유사한 것 같다. 너무 비현실적인 실험 세팅, 그리고 너무 낮은 성능을 보였다고 생각한다.