[SageMaker] SageMaker Jumpstart를 사용한 LLM Fine Tuning - Supervised fine tuning (SFT)

Sagemaker를 활용하여 Fine-tuning을 수행하였습니다.

파인튜닝이란?

파인튜닝이란 LLM의 성능을 유지하고 개선하기 위한 수단으로  특정 데이터를 활용하여 LLM을 학습시키는 방법을 의미합니다.파인튜닝의 경우 다음과 같은 장점이 있습니다.

하지만 단점으로는 다음과 같습니다.

따라서 파인튜닝이란 대부분의 경우 사용에 적합하진 않으며, 모델에 따라서 복잡한 초기 구성이 필요할 수 있습니다.뿐만 아니라 LLM을 학습시키기 위해 필요한 데이터 혹은 구성 방법 등이 잘 알려져 있지않은 경우가 대부분입니다. 하지만 SageMaker Jumpstart같은 경우 LLM 모델을 학습시키기 위해 필요한 데이터의 정보나 구성 등을 확인할 수 있고 클라우드 상에서 쉽게 LLM 모델을 파인튜닝하여 사용할 수 있게 됩니다.

Supervised fine tuning (SFT)란?

SFT는 레이블이 지정된 데이터를 사용하여 사전 훈련된 LLM을 특정 다운스트림(특정 작업이 가능) 작업이 가능하도록 적용하는데 사용할 수 있는 LLM을 학습시키는 기법 중 하나입니다.SFT에서 사전 훈련된 LLM은 지도 학습 기술을 사용하여 레이블이 지정된 데이터 세트에서 미세 조정됩니다. 모델의 가중치는 LLM의 예측과 실제 레이블 간의 차이를 측정하는 작업별 손실에서 파생된 기울기를 기반으로 조정됩니다.SFT는 일반적으로 3단계 프로세스를 거칩니다.

  1. 1단계: 사전 훈련 — 기본 또는 기초 모델은 처음에 대규모 데이터 세트에서 훈련되어 주어진 문장에서 다음 단어를 예측하여 언어 패턴, 문법 및 컨텍스트를 이해하는 방법을 학습합니다. 이 단계는 모델이 언어에 대한 광범위한 이해가 가능하도록 설정할 수 있습니다.
  2. 2단계: 데이터 라벨링 — 미세 조정에 사용되는 데이터 세트를 준비합니다. 각 데이터 포인트에는 올바른 출력이나 답이 표시되어 있습니다. 이 레이블이 지정된 데이터는 미세 조정 프로세스 중에 모델이 매개변수를 조정하도록 안내하므로 지도 학습에서 가장 중요한 단계입니다.
  3. 3단계: 미세 조정 — 사전 훈련된 모델은 레이블이 지정된 데이터가 있는 작업별 데이터 세트에 대해 추가로 훈련됩니다. 모델은 이 특정 작업의 성능을 향상시키기 위해 매개변수를 조정합니다. 이러한 작업은 텍스트 분류, 감정 분석, 질문 답변 시스템 등 다양한 곳에서 사용할 수 있습니다.

EX) 라벨링된 테스트 데이터는 다음과 같이 구성되어 있습니다.

{"question":"What part of a plant protects the plant cell, maintains its shape, and prevents excessive uptake of water?","correct_answer":"wall"}
{"question":"Bones, cartilage, and ligaments make up what anatomical system?","correct_answer":"skeletal system"}
{"question":"What is the term for the force exerted by circulating blood on the walls of blood vessels?","correct_answer":"blood pressure"}
{"question":"What is the most diverse and abundant group of organisms on earth, numbering in the millions of trillions?","correct_answer":"bacteria"}
{"question":"Energy from sunlight enters many ecosystems through what process?","correct_answer":"photosynthesis"}
{"question":"Modern plants reflect what kind of changes that have occurred over many, many years?","correct_answer":"evolutionary changes"}

이러한 라벨링 데이터를 학습시키기 위해 파인튜닝하기 위한 기법으로는 LoRA가 대표적이며 훈련 가능한 매개변수의 수를 줄이고 미세 조정을 더욱 효율적으로 만들 수 있습니다.

발생할 수 있는 문제

보편적으로 파인튜닝에서 발생할 수 있는 문제들이 포함되어 있습니다.하지만 SFT 방식은 LLM이 훈련하는 기본적인 방식으로 성능을 향상시키는 가장 유용한 방식으로써 사용될 수 있습니다.

파인튜닝 설정하기

1단계: 파인튜닝 가능한 모델(Llama2 7B Chat)을 선택합니다.

Jumpstart > Meta > Llama 2 7B Chat 모델을 선택하고 Train 버튼을 클릭합니다.

추가적으로 문서를 확인해보면 새로운 데이터셋을 통해 파인튜닝하는 방법을 안내하고 있습니다.

2단계: 파인튜닝 세부 구성

Datasetinstructioncontext⁠가 통합되어야 하지만 좀 더 명확한 지침을 내리기 위해 프롬프트 템플릿을 생성하고 Question과 Answer로 구성된 데이터셋을 사용하여 전달합니다.

{"question":"What part of a plant protects the plant cell, maintains its shape, and prevents excessive uptake of water?","correct_answer":"wall"}

해당 트레이닝에서는 Epochs와 Instruction-Train The Model과 Chat Dataset Format만 설정하고 나머지는 디폴트로 설정하였습니다. 파라미터의 설명은 다음과 같습니다.

3단계: 모델 저장 및 훈련 등록

훈련된 모델을 저장할 위치를 설정합니다.

4단계 작업 완료

Jobs > Training > Training Jobs에서 실행중인 Jumpstart 모델을 확인할 수 있습니다.해당 Training Job이 완료되면 다음과 같이 퍼포먼스 결과가 출력됩니다.

5단계: 테스트

해당 모델을 배포하여 성능 테스트를 진행해보도록 하겠습니다.

이후 Deployments > Endpoints로 모델이 자동적으로 배포되며 GPU 상황에 따라 배포되며 배포가 실패할 경우 GPU 상태를 확인해보시는 것을 권장드립니다.배포가 완료되었다면 생성된 Endpoint ID를 endpoint_name으로 설정하여 테스트를 진행할 수 있습니다.

sft_endpoint_name = "jumpstart-dft-meta-textgeneration-l-20240625-064708"

아래 테스트 코드와 평가를 위한 데이터셋입니다.

위 테스트 데이터셋을 통해 출력한 결과입니다.

6단계: SageMaker Test Inference

현재 배포된 SageMaker Endpoint에서 Test inference를 진행할 수 있습니다.

Training dataset에서 일부 데이터를 가져와 질문한 결과 다음과 같이 출력되었습니다.

{
    "inputs": "<s>[INST] How do bipedal creatures walk? [/INST] ",
    "parameters": {
        "max_new_tokens": 256,
        "top_p": 0.9,
        "temperature": 0.6,
        "decoder_input_details": true,
        "details": true
    }
}{"question":"How do bipedal creatures walk?","correct_answer":"on two legs"}{
    "body": [
        {
            "generated_text": "<s>[INST] How do bipedal creatures walk? [/INST]  Bipedal creatures, such as humans and some animals, walk by using their legs and feet to support their body weight and move around. Here is a general overview of the mechanics of bipedal walking:\n\n1. Posture: Bipedal creatures stand upright on two legs, with the body balanced over the feet. The legs are straight and the feet are flat on the ground.\n2. Weight distribution: The body weight is distributed evenly between both legs, with the majority of the weight on the balls of the feet.\n3. Muscle contractions: The muscles in the legs, particularly the quadriceps and hamstrings, contract and relax to move the legs forward. The muscles in the feet also contract and relax to help maintain balance and stability.\n4. Foot strike: The feet make contact with the ground in a specific way, with the heel striking first, followed by the ball of the foot. This helps to absorb the impact of each step and maintain balance.\n5. Ankle movement: The ankles move through a range of motion to help distribute the weight of the body and maintain balance.\n6. Knee movement", 

...
}

테스트한 결과 질문에 대한 응답은 하였지만 조금 광범위하게 학습 데이터를 출력했음을 확인하였습니다. 따라서 만약 더 구체적인 결과나 혹은 특정한 형식의 출력을 원한다면 프롬프트를 제어하거나, 더 구체적인 데이터를 출력해야합니다.