본문으로 이동

JAX

위키백과, 우리 모두의 백과사전.
JAX
개발자구글, 엔비디아[1]
미리보기 버전
v0.4.31 / 2024년 7월 30일(10개월 전)(2024-07-30)
저장소(영어) jax - 깃허브
프로그래밍 언어파이썬, C++
운영 체제리눅스, macOS, 윈도우
플랫폼파이썬, NumPy
크기9.0 MB
종류머신러닝
라이선스아파치 라이선스 2.0
웹사이트jax.readthedocs.io/en/latest/ 위키데이터에서 편집하기

JAX는 고성능 수치 계산 및 대규모 머신러닝에 사용되는 파이썬 라이브러리이다. 구글에서 개발했으며 엔비디아와 다른 커뮤니티 기여자들의 기여를 받았다.[2][3][4]

수정된 버전의 autograd(함수의 미분을 통해 기울기 함수를 자동으로 얻는 기능)와 OpenXLA의 XLA (Accelerated Linear Algebra)를 결합한 것으로, NumPy의 구조와 워크플로우를 최대한 따르도록 설계되었으며, 텐서플로파이토치와 같은 기존 프레임워크와 함께 작동한다.[5][6] JAX의 주요 기능은 다음과 같다.[7]

  1. CPU, GPU, TPU에서 로컬 또는 분산 설정으로 실행되는 계산에 NumPy와 유사한 통합 인터페이스 제공.
  2. 오픈 소스 머신러닝 컴파일러 생태계인 Open XLA를 통한 내장 Just-In-Time (JIT) 컴파일.
  3. 자동 미분 변환을 통한 기울기의 효율적인 평가.
  4. 입력 배치들을 나타내는 배열에 효율적으로 매핑되도록 자동 벡터화.

grad

[편집]

아래 코드는 grad 함수의 자동 미분을 보여준다.

# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)

마지막 줄은 다음을 출력해야 한다.

0.19661194

jit

[편집]

아래 코드는 jit 함수의 융합을 통한 최적화를 보여준다.

# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)

jit_cube (17번째 줄)의 계산 시간은 cube (16번째 줄)보다 눈에 띄게 짧아야 한다. 7번째 줄의 값을 늘리면 차이는 더욱 심화된다.

vmap

[편집]

아래 코드는 vmap 함수의 벡터화를 보여준다.

# imports
from jax import vmap partial
import jax.numpy as jnp

# define function
def grads(self, inputs):
    in_grad_partial = jax.partial(self._net_grads, self._net_params)
    grad_vmap = jax.vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads
벡터화된 덧셈 시연 영상

pmap

[편집]

아래 코드는 행렬 곱셈을 위한 pmap 함수의 병렬 처리를 보여준다.

# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)

마지막 줄은 다음과 같은 값을 출력해야 한다.

[1.1566595 1.1805978]

같이 보기

[편집]

외부 링크

[편집]

참고 자료

[편집]
  1. “jax/AUTHORS at main · jax-ml/jax”. 《깃허브. 2024년 12월 21일에 확인함. 
  2. Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao (2022년 6월 18일), “JAX: Autograd and XLA”, 《Astrophysics Source Code Library》 (Google), Bibcode:2021ascl.soft11002B, 2022년 6월 18일에 원본 문서에서 보존된 문서, 2022년 6월 18일에 확인함 
  3. Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018년 2월 2일). “Compiling machine learning programs via high-level tracing” (PDF). 《MLsys》: 1–3. 2022년 6월 21일에 원본 문서 (PDF)에서 보존된 문서. 
  4. “Using JAX to accelerate our research”. 《www.deepmind.com》 (영어). 2022년 6월 18일에 원본 문서에서 보존된 문서. 2022년 6월 18일에 확인함. 
  5. Lynley, Matthew. “Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta”. 《Business Insider》 (미국 영어). 2022년 6월 21일에 원본 문서에서 보존된 문서. 2022년 6월 21일에 확인함. 
  6. “Why is Google's JAX so popular?”. 《Analytics India Magazine》 (미국 영어). 2022년 4월 25일. 2022년 6월 18일에 원본 문서에서 보존된 문서. 2022년 6월 18일에 확인함. 
  7. “Quickstart — JAX documentation”. 

틀:Differentiable computing