강화학습/ML-Agent Unity

Agent Script - Dodge

유니티런 2025. 4. 28. 14:56

네임스페이스

using System.Linq; //쿼리, 데이터조회 추출기능
using Unity.MLAgents.Actuators; //파이썬 코드에서 행동을 받고 에이전트를 제어하는 기능
using Unity.MLAgents.Sensors; //관측정보를 사용하기 위한 기능

이번 스크립트는 Agent로 부터 상속 받습니다.

변수

    public DArea area;  // Area 스크립트
    Rigidbody RbAgent;  // 에이전트의 리지드바디
    float speed = 3f; // 에이전트 속더
    Vector3 centerPos = Vector3.zero;  //에이전트 위치
    public float DecisionWaitingTime = 0.05f;  // 함수를 호출하기까지 기다리는 시간
    float m_currentTime = 0f;  //Decision Requester를 위한 시간기록

유니티 머신러닝 에이전트에서는 DecisionRequest 호출을 받아야 에이전트 행동수행이 가능합니다. 공식적으로 지원하는 DecisionRequest를 사용해도 되지만에피소드종료후 다음 에피소드의 첫번째 상태가 나오지 않는 버그가 있어 코드내에서 DecisionRequest관련 내용을 작성합니다.

 

 

Initialize()함수는  Agent에서 Start()와 같은 함수로 초기화에 사용됩니ㅏㄷ.

OnEpisodeBegin()함수도 에이전트에서 제공되는 함수로 에피소드가 실행될때 호출되는 함수입니다.

먼저 Area.ResetEnv()호출 게임환경을 초기화하고 에이전트의 상태를 설정합니다.

3x3=9개의 환경을 테스트 하므로 agent의 위치는 local로 제어합니다.

 

CollectObservations(VectorSensor sensor)함수는 에이전트에서 제공하는 함수로 벡터관측정보를 수집하기 위해 사용하는 함수 입니다.

닷지환경에서는 RayCast 광선을 기반으로 거리를 측정합니다.rayCount가 40개 이므로 360/40=9도 발사각을 설정합니다.

레이를 쏴서 충돌을 감지하기 위해서는 우선 2개의 구조체를 마련하고

      Ray ray;  광선을 정의 하기 위한 구조체

      RaycastHit hit; 충돌을 탐지하기 위한 구조체

광선을 만들고 충돌을 감지합니다. 

ray = new Ray(gameObject.transform.position, new Vector3(Mathf.Cos(angle), 0, Mathf.Sin(angle)));
if (Physics.Raycast(ray, out hit))  //ray에 의한 충돌정보는 out hit에 담깁니다.

out 키워드 : out은 무엇일까. C#에서는 out을 직접 매개변수의 값을 바꿀 수 있는 매개변수 한정자로 사용한다. 함수의 매개변수는 복사되서 전달되므로 함수내의 변경은 로컬적으로만 이루어 지는데 이걸 바뀌게 해준다  c++의 포인터나 레페런스참조자 같은것.

출처: https://kimasill.tistory.com/entry/Unity-레이캐스트Raycast-사용-시-out-키워드의-의미 [Logic:티스토리]

 

[Unity] 레이캐스트 (Physics.Raycast)

Unity에서 광선을 쏘아 충돌체를 감지할 수 있는 Physics.Raycast알아보고, 이를 디버깅할 때 사용할 수 있는 Draw.Debug에 대해 간단히 알아보자.  개념시작 지점에서 특정 방향으로 씬의 모든 충돌체

sikpang.tistory.com

public override void CollectObservations(VectorSensor sensor)
{
    RaycastHit hit;
    Ray ray;

    float angle;
    int raycount = 40;
    List<Vector3> debugRay = new List<Vector3>();

    for (int i = 0; i < raycount; i++)
    {
        angle = i * 2.0f * Mathf.PI / raycount;
        ray = new Ray(gameObject.transform.position, new Vector3(Mathf.Cos(angle), 0, Mathf.Sin(angle)));

        if (Physics.Raycast(ray, out hit))
        {
            sensor.AddObservation(hit.distance); // 충돌체와의 거리 정보를 센서에 추가합니다.

            if (hit.collider.gameObject.name == "ArenaWalls")
            {  //레이가 벽과 충돌한 경우 (0,0)을 벡터 정보로 추가
                sensor.AddObservation(new Vector2(0, 0));
            }
            else
            {  //벽이 아니면 공이고 리지드바디의 속도 정보를 정보고 추가합니다.
                Rigidbody rig = hit.collider.gameObject.GetComponent<Rigidbody>();
                var vel = new Vector2(rig.velocity.x, rig.velocity.z);
                sensor.AddObservation(vel);
            }
            debugRay.Add(hit.point); //충돌지점을 디버그레이에 추가합니다.
        }
    }

    sensor.AddObservation(RbAgent.velocity.x);  //에이전트의 x속도 추가
    sensor.AddObservation(RbAgent.velocity.z); //에이전트의 z속도 추가

    for (int i = 0; i < debugRay.Count; i++)
    { //레이의 시각화
        Debug.DrawRay(gameObject.transform.position, debugRay[i] - gameObject.transform.position, Color.green);
    }
}

OnActionReceived()는 ML_Agent에서 제공하는 함수로 파이썬으로 부터 행동을 받고 에이전트의 행동 보상 게임종료 조건등을 결정하는 함수 입니다. 

파이썬으로 부터 받은 이산적인 행동에대한 정보 actionBuffer.DiscreteAction[0] 의 값을 action에 설정합니다.

siwtch문을 통해 action에 따라 상하좌우로 움직여 주고 해당사항이 없을경우 정지합니다. velcocitychange모드는 가속없이 직접 속도를 바꿔주는 모드입니다.

public override void OnActionReceived(ActionBuffers actionBuffers)
{
    var action = actionBuffers.DiscreteActions[0];
    Vector3 force = Vector3.zero;

    switch (action)
    {
        case 1: force = new Vector3(-1, 0, 0) * speed; break;
        case 2: force = new Vector3(0, 0, 1) * speed; break;
        case 3: force = new Vector3(0, 0, -1) * speed; break;
        case 4: force = new Vector3(1, 0, 0) * speed; break;
        default: force = new Vector3(0, 0, 0) * speed; break;
    }
    RbAgent.AddForce(force, ForceMode.VelocityChange);

Physics.OverlapBox(위치, 크기)를 이용 에이전트와 겹치는 게임오브젝트중 Balls을 찾아 냅니다. Block의 길이가 0이 아닌 경우 즉 볼과 충돌했을경우 Reward를 -1로 설정하고 에피소드를 끝냅니다. 아닌경우 Reward로 0.1를 부여합니다.

코드중 =>는 C#에서 람다식 또는 ExpressionBodiedMember라는데 조건식 같은것 같다. block[]안에 Tag가 "ball'이 있는을 where()로 쿼리해서 그 배열의 길이가 0이 아니라는 것.

block.Where(Col => Col.gameObject.CompareTag("ball")).ToArray().Length != 0

 Collider[] block = Physics.OverlapBox(gameObject.transform.position, Vector3.one * 0.5f);

 if (block.Where(Col => Col.gameObject.CompareTag("ball")).ToArray().Length != 0)
 {
     SetReward(-1f);
     EndEpisode();
 }
 else
 {
     SetReward(0.1f);
 }

Heuristic()함수는 에이전트에서 제공되는 함수로 특정규칙대로 에이전트를 행동하도록 설정합니다. 보통 유저의 키 입력에 따라 에이전트를 제어하도록하며 유저의도대로 잘 동작하는지 디버깅할수 있습니다. 입력키코드가 WASD일경우 discreteActionOut에 해당되는 값을 넣어주네요

 public override void Heuristic(in ActionBuffers actionsOut)
 {
     var discreteActionsOut = actionsOut.DiscreteActions;
     discreteActionsOut[0] = 0;

     if (Input.GetKey(KeyCode.W))
     {
         discreteActionsOut[0] = 1;
     }
     if (Input.GetKey(KeyCode.D))
     {
         discreteActionsOut[0] = 2;
     }
     if (Input.GetKey(KeyCode.A))
     {
         discreteActionsOut[0] = 3;
     }
     if (Input.GetKey(KeyCode.S))
     {
         discreteActionsOut[0] = 4;
     }
 }

 WaitTimeInference()에서는 requestDecision()함수를 호출하여 특정 주기에 따라 에이전트 행동이 실제적으로 수행되도록 합니다. 이를 위해 초기 변수 설정에서 decsionWaitingTime과 m_CurrentTime값을 설정했습니다.

public void WaitTimeInference(int action)
{
    if (Academy.Instance.IsCommunicatorOn)
    {
        RequestDecision();
    }
    else
    {
        if (m_currentTime >= DecisionWaitingTime)
        {
            m_currentTime = 0f;
            RequestDecision();
        }
        else
        {
            m_currentTime += Time.fixedDeltaTime;
        }
    }
}

커뮤니케이션이 되고 있지 않은 경우 직접 시간 계산을 해서 일정시간 주기바다 호출합니다.

전체코드

using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgents;
using Unity.MLAgentsExamples;
using UnityEngine;

public class DAgent : Agent
{
    public DArea area;
    Rigidbody RbAgent;
    float speed = 3f;
    Vector3 centerPos = Vector3.zero;

    public float DecisionWaitingTime = 0.05f;
    float m_currentTime = 0f;

    public override void Initialize()
    {
        RbAgent = gameObject.GetComponent<Rigidbody>();
        centerPos = gameObject.transform.position;

        Academy.Instance.AgentPreStep += WaitTimeInference;
    }

    public override void OnEpisodeBegin()
    {
        area.ResetEnv();

        transform.localPosition = centerPos;
        RbAgent.velocity = Vector3.zero;
        RbAgent.angularVelocity = Vector3.zero;
    }

    public void SetAgentSpeed(float speed_)
    {
        speed = speed_;
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        RaycastHit hit;
        Ray ray;

        float angle;
        int raycount = 40;
        List<Vector3> debugRay = new List<Vector3>();

        for (int i = 0; i < raycount; i++)
        {
            angle = i * 2.0f * Mathf.PI / raycount;
            ray = new Ray(gameObject.transform.position, new Vector3(Mathf.Cos(angle), 0, Mathf.Sin(angle)));

            if (Physics.Raycast(ray, out hit))
            {
                sensor.AddObservation(hit.distance);

                if (hit.collider.gameObject.name == "ArenaWalls")
                {
                    sensor.AddObservation(new Vector2(0, 0));
                }
                else
                {
                    Rigidbody rig = hit.collider.gameObject.GetComponent<Rigidbody>();
                    var vel = new Vector2(rig.velocity.x, rig.velocity.z);
                    sensor.AddObservation(vel);
                }

                debugRay.Add(hit.point);
            }
        }

        sensor.AddObservation(RbAgent.velocity.x);
        sensor.AddObservation(RbAgent.velocity.z);

        for (int i = 0; i < debugRay.Count; i++)
        {
            Debug.DrawRay(gameObject.transform.position, debugRay[i] - gameObject.transform.position, Color.green);
        }
    }

    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        var action = actionBuffers.DiscreteActions[0];
        Vector3 force = Vector3.zero;

        switch (action)
        {
            case 1: force = new Vector3(-1, 0, 0) * speed; break;
            case 2: force = new Vector3(0, 0, 1) * speed; break;
            case 3: force = new Vector3(0, 0, -1) * speed; break;
            case 4: force = new Vector3(1, 0, 0) * speed; break;
            default: force = new Vector3(0, 0, 0) * speed; break;
        }

        RbAgent.AddForce(force, ForceMode.VelocityChange);

        Collider[] block = Physics.OverlapBox(gameObject.transform.position, Vector3.one * 0.5f);

        if (block.Where(Col => Col.gameObject.CompareTag("ball")).ToArray().Length != 0)
        {
            SetReward(-1f);
            EndEpisode();
        }
        else
        {
            SetReward(0.1f);
        }
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = 0;

        if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = 1;
        }
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[0] = 2;
        }
        if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[0] = 3;
        }
        if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = 4;
        }
    }

    public void WaitTimeInference(int action)
    {
        if (Academy.Instance.IsCommunicatorOn)
        {
            RequestDecision();
        }
        else
        {
            if (m_currentTime >= DecisionWaitingTime)
            {
                m_currentTime = 0f;
                RequestDecision();
            }
            else
            {
                m_currentTime += Time.fixedDeltaTime;
            }
        }
    }
}