Web Analytics Made Easy - Statcounter

Bouncing ideas 생각 작업실/exp sharing 경험.실험 나누기

Fairseq transformer model에서 attention 뽑아내기

sleepy_wug 2024. 3. 8. 03:06

0. 요약

이 포스팅은 Fairseq을 이용해 train한 transformer model에서 attention weights를 뽑아내기 위한 노력의 과정을 기술한다.

 

나의 업무 환경이 이렇게 좋으면 참 멋질텐데

 

목차

     

    1. 이슈

    Fairseq이 attention weights를 순순히 내놓지 않는다. 선전포고다!

     

     

    우선 왜 attention weights를 뽑아내야 하는지에 대한 맥락부터 서술하고 시도해본 해결책 + 과정을 섹션 4부터 설명한다.

     

    3. 맥락

    Fairseq model이 꽤나 괜찮은 성능을 보였다. (Fairseq입문) (IPA변환기)

     

    이론가로서 나에게 중요한 건 모델의 성능 그자체는 아니다. 촘스키가 말했듯 언어학은 "engineered solution"이 아니라 "real solution"이 필요하다.[각주:1] 답은 나왔다. 엔지니어는 답을 내는 것이 목적이다. 더 효율적으로 빨리 내는 것이 방향성이다. 반면 "왜" 그 답이 나왔느냐를 설명하는 것이 이론가의 할 일이다. 다른 정보 없이 오직 음소배열 형태만 주었는데도 Transformer model은 "왜" 그 결정을 내렸나?

     

    그래서 경음화를 결정하는 요인을 파악하기 위해 attention weights을 뽑아내야 한다. 그렇게 할 수만 있다면 BertViz 등의 툴을 통해 NLP model의 attention을 시각화할 수 있다.

     

    BertViz tutorial[링크]에 보면, 곧장 시각화 예시부터 볼 수 있다.

     

    널리 알려져 있다시피 그리고 상식적으로 장애음뒤경음화(POT)는 선행 장애음이 경음화의 원인이 될 것이다. 그러나 L-Tensification의 경우 무조건 경음화가 아닌데, 단어의 어떤 부분을 보고 경음화 할지 (혹은 하지 않을지) 결정하는 것일까? 내 프로젝트가 제시하는 가설은 어두자음과 어중 양순음뒤 모음이 큰 요인이 된다는 것이다. 즉, 어떤 단어가 처음부터 ㄹ로 시작하면 L-Tensification을 안 적용할 likelihood가 높다. LT가 우선 한자어에서 (무슨이유에서든) 시작되었다면, ㄹ로 시작하는 차용어 냄새가 강하게 나는 단어에는 '안전하게' 안 적용 될 것이다. 마찬가지 논리로 양순음 뒤 ㅡ모음이 출현하는 브, 프 등은 단어정체성 파악에 매우 큰 힌트를 줄 것이다. LT를 시각화해보면 어두 자음 혹은 양순음뒤 모음에 불이 켜질것을 기대한다.

     

    Attention을 시각화하자!

     

    3.1 fairseq function을 wrapping하기

    그런데 문제는 Fairseq 모델이 돌아가는 과정에서 attention weight를 뽑아내기가 만만치 않다는 것이다. 그냥 Fairseq에서 쉽게쉽게 안 내놓는다.

     

    직접 fairseq의 codebase에 손을 대는 건 안 될 일이기 때문에, 차근차근 한 레이어 씩 순차적으로 되짚어가며 encoder self attention weight, decoder self attention weight, 그리고 cross_attention_weight를 뽑아내는 아래의 코드를 짰다.

    def forward_pass_with_attention_extraction(model, input_tokens):
        # Encoder forward pass
        encoder_self_attention_weights = []
        input_tokens = input_tokens.unsqueeze(0)
        encoder_output = model.encoder.embed_tokens(input_tokens)
        encoder_padding_mask = torch.zeros_like(input_tokens, dtype=torch.bool)  # no padding
        for layer in model.encoder.layers:
            encoder_output, self_attn_weights = layer.forward(encoder_output, encoder_padding_mask)
            encoder_self_attention_weights.append(self_attn_weights)
    
        # Decoder forward pass with a prepared target sequence (e.g., start tokens)
        decoder_input = torch.tensor([[model.task.target_dictionary.bos()]]).to(input_tokens.device)
        decoder_self_attention_weights = []
        decoder_cross_attention_weights = []
        for layer in model.decoder.layers:
            # For simplicity, assume decoder layers are adapted to return both self and cross-attention weights
            decoder_output, self_attn_weights, cross_attn_weights = layer(decoder_input, encoder_output)
            decoder_self_attention_weights.append(self_attn_weights)
            decoder_cross_attention_weights.append(cross_attn_weights)
            # Prepare decoder_input for the next layer or time step based on decoder_output
    
        return encoder_self_attention_weights, decoder_self_attention_weights, decoder_cross_attention_weights

     

    문제는 encoder_padding_mask에 zero like tensor를 pass했을 때 아래의 에러에 직면한다는 것. 모든 element가 False일 때도 0일 때도 동일한 에러가 났다.

    Traceback (most recent call last):
      File "...\2024-01-26 model_outputs\sandbox.py", line 62, in <module>
        forward_pass_with_attention_extraction(model=underlying_model,
      File "...\2024-01-26 model_outputs\sandbox.py", line 30, in forward_pass_with_attention_extraction
        encoder_output, self_attn_weights = layer.forward(encoder_output, encoder_padding_mask=encoder_padding_mask)
      File "...\miniconda3\envs\fairseq\lib\site-packages\fairseq\modules\transformer_layer.py", line 319, in forward
        output = torch._transformer_encoder_layer_fwd(
    RuntimeError: Mask Type should be defined

     

    내 모델은 encoder padding이 없기 때문에 논리적으로는 이렇게 주는 것이 타당하다. RuntimeError라는 데에서, 그리고  fairseq 코드 이후 더 깊이 Traceback이 이루어지지 않는다는 데에서, 아마도 Fairseq 문제가 아니라 Torch문제일 것 같다고 생각했다.

     

    4. 해결과정

    4.1 Torch downgrade

    몇 시간 동안 다양한 방법을 시도하였으나, 문제는 의외로 심심하게 풀렸다. Torch를 downgrade하니 해결되었던 것. 아마도 Torch codebase의 업데이트에 Fairseq이 발맞추지 않기 때문인 것 같다.

    Fairseq issue를 뒤지다가 해당 이슈에서 해결에 힌트를 받았다. 

    https://github.com/facebookresearch/fairseq/issues/4899

     

    Fairseq-generate giving me the error: 'RuntimeError: Mask Type should be defined' on Colab · Issue #4899 · facebookresearch/fa

    Some background: I'm working on a translation problem where I am able to get through the fairseq-preprocess and fairseq-train but during the process of fairseq-generate, the operation fails in the ...

    github.com

     

    나는 그냥 pip install 했기 때문에 torch 2.1.2 (아마도 최신버전?)를 사용했는데 그게 문제였던 것.

    # pip show torch
    Name: torch
    Version: 2.1.2
    Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
    Home-page: https://pytorch.org/
    Author: PyTorch Team
    Author-email: packages@pytorch.org
    License: BSD-3
    Location: c:\users\stanley\miniconda3\envs\fairseq\lib\site-packages
    Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
    Required-by: bertviz, fairseq, torchaudio

     

     

    다음부터는 fairseq 돌릴 때에는 아래의 command로 설치하여서 torch 버전을 1.12.1+cu113로 맞추어야 할 것이다.

    pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113

     

    이제 일단 돌아가기는 한다. 그러나 문제는 여전히 있는 듯하다. 왜냐하면 어떤 모델이든 (내 한국어 음운론 모델이든, Fairseq github repo에서 제공되는 WMT model이든) uniform attention이라는 괴랄한 결과를 내놓기 때문이다.

     

    encoder, decoder, cross-attention 모두 uniform하다는 것은, 모델이 input output 내의 각 요소에 모두 평등하게 attention을 준다는 건데, 그렇다면 퍼포먼스가 결코 나올 수 없기 때문이다.

     

    어떤 변형을 어디는 하고 다른데는 안 하기로 했으면 그 이유가 있을텐데, attention을 부분적으로 준 것이 아니라면 설명이 안 된다.

     

    You got me looking for attention!

    https://youtu.be/CHp0Kaidr14

     

     

    4.2 내부에 첩자 심기

    터부시된다는 걸 알지만 fairseq codebase에 손을 대기로 했다. 이게 왜 터부시되느냐하면, 나는 fairseq을 분석도구로 이용하기 때문이다.

     

    과학적 연구에서 분석도구에 연구자가 직접 손대는 것은 대체로 회피된다. 두 가지 이유가 있다. 첫째는 주관적 편향이 개입될 수 있기 때문이고 둘째는 재현가능성(reproducibility)을 위해서다. 주어져있는 도구가 있으면 그걸 순정으로 사용하는 게 권장된다. 

     

    function call 되는 fairseq MultiheadAttention class의 forward function가 내부연산을 위해 반드시 attention을 (이성적으로 생각했을 때 non-uniform한 attention을) 출력함에도 불구하고 내가 custom function으로 접근할 수 있는 것은 뭔가 잘못되어 있는 게 분명하다. 추측하건대, 나의 custom function이 받는 결과는 뭔가 levelling된 결과치인 듯하다. (즉 편차가 있는 attention weights를 외부 모듈로 내보낼 때 모종의 이유(?)로 평균치만 반출함)

     

    그래서 fairseq codebase의 forward() function에서 return 라인 직전에 아예 input() function을 박아버렸다.

            if need_weights:
                attn_weights = attn_weights_float.view(
                    bsz, self.num_heads, tgt_len, src_len
                ).transpose(1, 0)
                if not need_head_weights:
                    # average attention weights over heads
                    attn_weights = attn_weights.mean(dim=0)
                    
            input(f"inside MultiheadAttention forward(): \nattn:{attn}\nattn_weights{attn_weights}")
            return attn, attn_weights

     

    쉽게 말하자면 "연산 다 하고 나서 결과치 return하기 전에 나한테 확인받아. 내가 Enter 눌러야 지나갈 수 있어!" 하는 것이다.

     

    그랬더니 아래와 같이 논리적인 결과치를 얻을 수 있었다. (마지막 encoding layer에서 나온 값)

    attn:tensor([[[-0.2132,  1.6754,  1.2194, -0.2058,  0.2805,  0.0803,  0.2405,
               1.4939, -1.9362, -0.7559,  0.3412,  1.1786,  0.4615, -1.3100,
              -2.2807, -1.6114,  1.3796,  2.1927, -2.5260, -0.2914, -1.0734,
               1.0780, -0.3484, -0.7229, -2.5685,  0.1950,  2.8701,  0.0945,
              -0.2877,  0.9932, -2.0540, -1.4638, -0.3367, -1.5930,  2.7986,
               1.3678, -0.3007,  0.3272,  0.0844, -1.0528, -0.8851,  1.9094,
              -0.9410,  1.2149, -1.8184,  1.4367,  0.0538, -2.0376, -0.6654,
               0.9242,  0.2909, -3.2609,  2.4890,  1.4103,  0.9503,  0.5516,
               3.5299,  2.9578, -2.8321, -0.1681,  0.0710, -1.0526,  1.8459,
              -0.1562],
             [-0.2156,  1.6814,  1.2226, -0.2251,  0.2751,  0.0800,  0.2395,
               1.4976, -1.9262, -0.7625,  0.3610,  1.1767,  0.4708, -1.3056,
              -2.2812, -1.6046,  1.3759,  2.1898, -2.5154, -0.2943, -1.0719,
               1.0785, -0.3405, -0.7148, -2.5699,  0.1903,  2.8672,  0.0927,
              -0.2973,  1.0055, -2.0620, -1.4732, -0.3320, -1.5934,  2.7934,
               1.3686, -0.2876,  0.3350,  0.0687, -1.0580, -0.8973,  1.9121,
              -0.9232,  1.2163, -1.8185,  1.4359,  0.0578, -2.0348, -0.6611,
               0.9331,  0.3001, -3.2523,  2.4929,  1.4105,  0.9462,  0.5412,
               3.5355,  2.9551, -2.8449, -0.1743,  0.0666, -1.0537,  1.8505,
              -0.1483],
             [-0.2079,  1.6609,  1.2157, -0.2000,  0.2835,  0.0776,  0.2367,
               1.4868, -1.9322, -0.7489,  0.3435,  1.1754,  0.4643, -1.2977,
              -2.2636, -1.6001,  1.3789,  2.1766, -2.5075, -0.2817, -1.0705,
               1.0746, -0.3481, -0.7195, -2.5579,  0.1830,  2.8664,  0.1019,
              -0.2882,  0.9803, -2.0434, -1.4623, -0.3324, -1.5895,  2.7771,
               1.3603, -0.2920,  0.3242,  0.0822, -1.0457, -0.8658,  1.8925,
              -0.9451,  1.2097, -1.8141,  1.4297,  0.0426, -2.0254, -0.6666,
               0.9148,  0.2875, -3.2465,  2.4748,  1.4068,  0.9386,  0.5519,
               3.5059,  2.9420, -2.8108, -0.1725,  0.0685, -1.0514,  1.8344,
              -0.1548],
             [ 1.3226, -0.1522, -0.0240,  0.1194,  0.1759, -0.4110,  0.0858,
               0.1720, -0.3133,  0.4628,  0.3935,  1.6367,  0.4127,  0.5739,
               1.4142,  0.9822, -0.0548, -0.5897,  0.9229,  1.6793,  0.4204,
              -0.2348,  0.1772,  1.6079, -0.1124, -2.3561,  0.8576,  1.1747,
               1.2932, -2.8568, -0.9354, -2.1837,  1.1670, -0.0133, -1.3205,
              -0.6328,  1.1438, -0.5948, -1.1604,  0.2787,  2.5329, -1.3090,
              -1.6825, -1.3178, -0.0966,  0.3489, -1.2427, -1.2287, -0.4729,
              -1.0876, -1.8181, -1.0772, -1.8703,  0.1277,  0.0992,  1.6601,
              -1.7304, -0.1690,  0.7692, -0.1420,  0.6178,  0.7694,  0.0855,
              -1.0914],
             [-0.2056,  1.6656,  1.2233, -0.2213,  0.2813,  0.0735,  0.2379,
               1.4947, -1.9187, -0.7564,  0.3627,  1.1732,  0.4710, -1.2952,
              -2.2740, -1.6074,  1.3691,  2.1749, -2.5065, -0.2895, -1.0712,
               1.0807, -0.3409, -0.7089, -2.5610,  0.1806,  2.8592,  0.0899,
              -0.2993,  0.9882, -2.0522, -1.4692, -0.3294, -1.5884,  2.7715,
               1.3655, -0.2865,  0.3323,  0.0708, -1.0475, -0.8775,  1.9051,
              -0.9277,  1.2112, -1.8002,  1.4328,  0.0421, -2.0230, -0.6562,
               0.9297,  0.2984, -3.2506,  2.4798,  1.4110,  0.9486,  0.5466,
               3.5200,  2.9446, -2.8299, -0.1808,  0.0659, -1.0520,  1.8444,
              -0.1434]]])
    attn_weights:tensor([[[[1.5290e-04, 3.3297e-04, 7.4158e-04, 5.3844e-03, 9.7565e-01,
               1.7738e-02]],
             [[6.3505e-05, 1.7963e-04, 7.7132e-04, 6.7409e-03, 9.7707e-01,
               1.5170e-02]],
             [[2.2172e-04, 4.6036e-04, 8.4506e-04, 5.1407e-03, 9.7050e-01,
               2.2828e-02]],
             [[6.0432e-04, 1.9357e-04, 1.3874e-02, 2.2387e-02, 1.3759e-02,
               9.4918e-01]],
             [[2.3214e-04, 4.2990e-04, 6.7530e-04, 5.7531e-03, 9.7518e-01,
               1.7730e-02]]],
            [[[3.2012e-03, 8.4666e-03, 2.7519e-03, 5.6215e-02, 9.2797e-01,
               1.3966e-03]],
             [[2.9338e-03, 6.7742e-03, 2.0110e-03, 6.2265e-02, 9.2410e-01,
               1.9213e-03]],
             [[3.4868e-03, 1.0847e-02, 2.5210e-03, 5.7555e-02, 9.2420e-01,
               1.3940e-03]],
             [[2.3742e-02, 2.1352e-04, 3.2380e-03, 9.3723e-03, 4.7414e-02,
               9.1602e-01]],
             [[3.8614e-03, 1.1674e-02, 2.3378e-03, 6.3152e-02, 9.1777e-01,
               1.2073e-03]]]])

     

     


    • 아래에 댓글창이 열려있습니다. 로그인 없이도 댓글 다실 수 있습니다.
    • 글과 관련된 것, 혹은 글을 읽고 궁금한 것이라면 무엇이든 댓글을 달아주세요.
    • 반박이나 오류 수정을 특히 환영합니다.
    • 로그인 없이 비밀글을 다시면, 거기에 답변이 달려도 보실 수 없습니다. 답변을 받기 원하시는 이메일 주소 등을 비밀글로 남겨주시면 이메일로 답변드리겠습니다.

     

    1. 아마 1995년 노랑책 Minimalist Program에서, 1980년대까지의 지배결속이론의 성격과 그 복잡성을 한 마디로 표현하던 부분에서 언급된 표현인 것으로 기억한다. 정확한 출처는 찾아봐야 함. [본문으로]
    반응형