본문 바로가기
Dev/C++

Template Meta Programming으로 matrix 라이브러리 만들기 (3)

by Jino Park 2022. 7. 30.
반응형

이전 포스팅에서는 Expression tree를 위한 MatrixExpr, 기본적인 행렬을 표현하는 Matrix, 행렬 간 덧셈을 표현하는 MatrixSum을 구현 후 불필요한 복사가 없는 행렬 덧셈을 구현하였다. Expression tree를 만드는 과정과 이를 evaluation 하는 과정을 분리할 수 있었고, 이를 통해 실제 값이 필요한 시점에 계산을 할 수 있다 (lazy evaluation). 이와 비슷한 방법으로 행렬 간 뺄셈, element-wise 한 곱셈 (아다마르 곱) 등을 구현할 수 있을 것이다.

하지만 lazy evaluation은 모든 경우에 대해 항상 효율적이지 못하다. 이번 포스팅에서는 그 대표적인 예 중 하나인 행렬간 곱셈을 구현하면서 자세히 알아보자.

1. MatrixMultNaive 클래스

이전 포스팅에서 MatrixSum 클래스를 만든 것 처럼 똑같이 MatrixExpr을 상속받아 MatrixMultNaive 클래스를 만들어보자. MatrixSum과 마찬가지로 좌항/우항의 expression을 각각 E1, E2로 가져오고, static polymorphism을 위해 CRTP를 사용하였다. 한 가지 차이점은 requires 문인데, 행렬의 곱셈에서 좌항의 열의 수와 우항의 행의 수가 같아야 하는 것을 나타냈다.

template<typename E1, typename E2>
    requires(E1::col == E2::row)
struct MatrixMultNaive : public MatrixExpr<MatrixMultNaive<E1, E2>> {
	// ...
};

row, col, Type을 정의해보자. 곱해진 행렬은 좌항(E1)의 행, 그리고 우항(E2)의 열의 크기를 갖는다. Type은 우선 좌항의 Type을 사용하자.

using Type = typename E1::Type;
static constexpr Index row = E1::row;
static constexpr Index col = E2::col;

그리고 operand를 받는 생성자, 그리고 그 레퍼런스를 저장하는 멤버 x, y를 선언/구현했다.

MatrixMultNaive(const E1 &x, const E2 &y) : x{x}, y{y} {}

const E1 &x;
const E2 &y;

이제 특정 원소의 값을 계산하는 elem()함수를 구현해보자. 좌항의 행, 우항의 열을 dot product를 한 것과 같다.

// Static polymorphism implementation of MatrixExpr
inline auto elem(Index r, Index c) const {
    Type ret = static_cast<Type>(0);
    for (Index i = 0; i < E1::col; i++) {
        ret += x.elem(r, i) * y.elem(i, c);
    }
    return ret;
}

전체 코드는 다음과 같다. MatrixMultNaive를 생성하는 operator*() 함수도 operator+*()와 같이 만들 수 있다.

template<typename E1, typename E2>
    requires(E1::col == E2::row)
struct MatrixMultNaive : public MatrixExpr<MatrixMultNaive<E1, E2>> {
    using Type = typename E1::Type;
    static constexpr Index row = E1::row;
    static constexpr Index col = E2::col;

    MatrixMultNaive(const E1 &_x, const E2 &_y) : x{_x}, y{_y} {}

    // Static polymorphism implementation of MatrixExpr
    inline auto elem(Index r, Index c) const {
        Type ret = static_cast<Type>(0);
        for (Index i = 0; i < E1::col; i++) {
            ret += x.elem(r, i) * y.elem(i, c);
        }
        return ret;
    }

    const E1 &x;
    const E2 &y;
};

template<typename E1, typename E2>
    requires(E1::col == E2::row)
MatrixMultNaiveE1, E2> operator*(const MatrixExpr<E1> &x, const MatrixExpr<E2> &y) {
    return MatrixMultNaive<E1, E2>(static_cast<const E1 &>(x), static_cast<const E2 &>(y));
}

2. 결과

계산이 잘 되는 것을 확인할 수 있다.

3. MatrixMultNaive의 단점

i행 j열의 element를 계산하는 과정을 생각해보자. 이때 필요한 것은 좌항(E1)의 i번째 행, 그리고 우항(E2)의 j번째 행이다. 

i번째 행과 j번째 열의 dot product가 i, j번째 element의 값이다.

우리는 이를 elem() 함수에서 각 행과 열의 element를 재귀적으로 evaluation한 후, 이를 곱한 값을 더하는 방법으로 구현하였다. 이제 i, j+1번째 element를 구하는 경우를 다시 생각해보자. 앞서 이미 (i, j) element를 계산할 때 이미 i번째 항을 계산했지만, 그 결과를 따로 저장하지 않았기 때문에 좌항의 i번째 행을 중복으로 계산하게 된다. 중복되는 계산은 리소스의 낭비로 이어지게 되고, 이는 matrix의 크기가 커질수록, 그리고 matrix가 줄지어 곱해질수록 심해질 것이다.

i번째 행과 j+1번째 열의 dot product가 i, j+1번째 element의 값이다. 앞서 이미 좌항의 i번째 행을 계산했음에도 불구하고 다시 한 번 동일한 계산을 수행한다.

4. MatrixMult 클래스

앞서 작성한대로, lazy evaluation이 효율적이지 않은 경우 중 하나가 바로 행렬 간 곱셈이고, 그 이유는 중간 결과를 저장하는 것을 고려하지 않기 때문이다. 이를 피하기 위해 operand를 미리 evaluation한 후 이를 캐싱하면 불필요한 계산을 수행하는 것을 막을 수 있을 것이다.  이를 위해선 기존의 멤버 레퍼런스였던 x, y 대신 Matrix type을 가지게 수정하면 된다.

MatrixMult(const E1 &x, const E2 &y) : x{x}, y{y} {}

const Matrix<Type, E1::row, E1::col> x; // const E1 &x;
const Matrix<Type, E2::row, E2::col> y; // const E2 &y;

이미 Matrix 클래스는 임의의 MatrixExpr을 인자로 갖는 생성자를 가지고 있고 (이전 포스팅 참고), 이제 MatrixMult를 생성하는 시점에 두 operand를 evaluation할 것이다. elem()은 이미 계산이 끝난 operand를 이용해 계산을 하기 때문에 중복돼서 operand evaluation을 하지 않는다.

5. 결과

Naive한 방법으로 구현한 MatrixMultNaive 는 평균적으로 4초가 넘게 걸렸는데, lazy evaluation을 제거한 MatrixMult는 0.04초 정도 걸렸다. (전체 코드)

6. 마무리 & 여담

  • Peanut 을 만들기 위한 기본적인 아이디어/구현 방법을 소개하는 시리즈였다.
  • Eigen, armadillo와도 비교를 해 봤는데 확실히 매우 느리다 (benchmark 브랜치). 그래도 기본적인 pseudo code는 같아 보임
  • binary operation들만 소개를 했지만 이와 같은 방법으로 unary operation도 구현이 가능하다. 이 포스트 작성 시점에 Peanut에는 adjugate, block, cast, cofactor, inverse, minor, submatrix, transpose가 구현이 되어 있다.
  • Expression tree를 생성하는 함수도 들어오는 입력에 따라 다른 output을 내는 방법으로도 최적화가 가능하다. 예를 들어 transpose의 transpose는 T(T(x))가 아니라 x를 return하는 방식으로...
반응형