New additions to the JAX ecosystem: DeepMind open-sources Haiku and RLax
Source: 机器之心
JAX is an excellent library: it offers automatic differentiation alongside scientific computing, with GPU and TPU acceleration on top. Its ecosystem is still immature, however, and its user base is far smaller than that of TensorFlow or PyTorch. Recently, DeepMind open-sourced two new JAX-based libraries, Haiku and RLax, injecting fresh vitality into this ecosystem.
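As a minimal illustration of those two capabilities (not from the original article), the sketch below uses jax.grad for automatic differentiation and jax.jit for XLA compilation; the function f is just an arbitrary example:

import jax
import jax.numpy as jnp

def f(x):
  return jnp.sum(jnp.sin(x) ** 2)

grad_f = jax.grad(f)           # gradient of f, derived automatically
fast_grad_f = jax.jit(grad_f)  # compiled with XLA; runs on CPU/GPU/TPU

x = jnp.arange(3.0)
print(fast_grad_f(x))          # elementwise 2 * sin(x) * cos(x)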
Customize your own module

The Haiku example below subclasses hk.Module to define a custom linear layer, then uses hk.transform to turn a function that uses it into a pair of pure init/apply functions:
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

class MyLinear(hk.Module):

  def __init__(self, output_size, name=None):
    super(MyLinear, self).__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    j, k = x.shape[-1], self.output_size
    w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
    return jnp.dot(x, w) + b

def forward_fn(x):
  model = MyLinear(10)
  return model(x)

# Turn `forward_fn` into an object with `init` and `apply` methods.
forward = hk.transform(forward_fn)

x = jnp.ones([1, 1])

# When we run `forward.init`, Haiku will run `forward(x)` and collect initial
# parameter values. Haiku requires you pass an RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk.PRNGSequence(42)
params = forward.init(next(key), x)

# When we run `forward.apply`, Haiku will run `forward(x)` and inject parameter
# values from the `params` that are passed as the first argument. We do not require
# an RNG key by default since models are deterministic. You can (of course!) change
# this using `hk.transform(f, apply_rng=True)` if you prefer:
y = forward.apply(params, x)
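As a quick sanity check (not part of the original article), you can inspect the structure `init` returned; Haiku stores parameters as a per-module mapping of named arrays, and for the example above the module name should follow Haiku's default snake_case naming:

# Print the shape of every parameter created by `forward.init` above.
print(jax.tree_util.tree_map(lambda p: p.shape, params))
# Expected output is roughly: {'my_linear': {'b': (10,), 'w': (1, 10)}}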
For models that keep internal state, such as the moving averages in batch norm, Haiku provides hk.transform_with_state:

def forward(x, is_training):
  net = hk.nets.ResNet50(1000)
  return net(x, is_training)

forward = hk.transform_with_state(forward)

# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
# (`rng` is a JAX PRNG key and `x` an image batch, assumed defined elsewhere.)
params, state = forward.init(rng, x, is_training=True)

# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits, state = forward.apply(params, state, rng, x, is_training=True)
Haiku also composes with the rest of JAX: the next example replicates parameters across all local devices and runs data-parallel training with jax.pmap.

def loss_fn(inputs, labels):
  logits = hk.nets.MLP([8, 4, 2])(inputs)
  # `softmax_cross_entropy` is assumed to be defined elsewhere.
  return jnp.mean(softmax_cross_entropy(logits, labels))

loss_obj = hk.transform(loss_fn)

# Initialize the model on a single device.
rng = jax.random.PRNGKey(428)
# `input_dataset` is an iterator of (image, label) batches, defined elsewhere.
sample_image, sample_label = next(input_dataset)
params = loss_obj.init(rng, sample_image, sample_label)

# Replicate params onto all devices.
num_devices = jax.local_device_count()
params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params)

def make_superbatch():
  """Constructs a superbatch, i.e. one batch of data per device."""
  # Get N batches, then split into list-of-images and list-of-labels.
  superbatch = [next(input_dataset) for _ in range(num_devices)]
  superbatch_images, superbatch_labels = zip(*superbatch)
  # Stack the superbatches to be one array with a leading dimension, rather than
  # a python list. This is what `jax.pmap` expects as input.
  superbatch_images = np.stack(superbatch_images)
  superbatch_labels = np.stack(superbatch_labels)
  return superbatch_images, superbatch_labels

def update(params, inputs, labels, axis_name='i'):
  """Updates params based on performance on inputs and labels."""
  grads = jax.grad(loss_obj.apply)(params, inputs, labels)
  # Take the mean of the gradients across all data-parallel replicas.
  grads = jax.lax.pmean(grads, axis_name)
  # Update parameters using SGD or Adam or ...
  new_params = my_update_rule(params, grads)
  return new_params

# Run several training updates.
for _ in range(10):
  superbatch_images, superbatch_labels = make_superbatch()
  params = jax.pmap(update, axis_name='i')(params, superbatch_images,
                                           superbatch_labels)
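The data-parallel example above leaves `my_update_rule` undefined, per its own comment ("SGD or Adam or ..."). A hypothetical stand-in, plain SGD with an arbitrary illustrative learning rate, might look like this:

# Hypothetical stand-in for `my_update_rule`: one step of plain SGD.
def my_update_rule(params, grads, lr=1e-2):  # `lr` chosen only for illustration.
  return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)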
The second library, RLax, provides building blocks for reinforcement-learning agents. The example below trains a Q-learning agent with an epsilon-greedy policy on the Catch environment, combining Haiku for the network, RLax for the loss and policies, and optix for optimization:

from absl import flags
from bsuite.environments import catch
import haiku as hk
from haiku import nets
import jax
import jax.numpy as jnp
from jax.experimental import optix  # `optix` later became the standalone `optax` library.
import rlax

# Hyperparameters (seed, hidden_units, learning_rate, epsilon, train_episodes)
# are defined as absl flags elsewhere in the example.
FLAGS = flags.FLAGS

def build_network(num_actions: int) -> hk.Transformed:

  def q(obs):
    flatten = lambda x: jnp.reshape(x, (-1,))
    network = hk.Sequential(
        [flatten, nets.MLP([FLAGS.hidden_units, num_actions])])
    return network(obs)

  return hk.transform(q)

def main_loop(unused_arg):
  env = catch.Catch(seed=FLAGS.seed)
  rng = hk.PRNGSequence(jax.random.PRNGKey(FLAGS.seed))

  # Build and initialize Q-network.
  num_actions = env.action_spec().num_values
  network = build_network(num_actions)
  sample_input = env.observation_spec().generate_value()
  net_params = network.init(next(rng), sample_input)

  # Build and initialize optimizer.
  optimizer = optix.adam(FLAGS.learning_rate)
  opt_state = optimizer.init(net_params)

  @jax.jit
  def policy(net_params, key, obs):
    """Sample action from epsilon-greedy policy."""
    q = network.apply(net_params, obs)
    a = rlax.epsilon_greedy(epsilon=FLAGS.epsilon).sample(key, q)
    return q, a

  @jax.jit
  def eval_policy(net_params, key, obs):
    """Sample action from greedy policy."""
    q = network.apply(net_params, obs)
    return rlax.greedy().sample(key, q)

  @jax.jit
  def update(net_params, opt_state, obs_tm1, a_tm1, r_t, discount_t, q_t):
    """Update network weights wrt Q-learning loss."""

    def q_learning_loss(net_params, obs_tm1, a_tm1, r_t, discount_t, q_t):
      q_tm1 = network.apply(net_params, obs_tm1)
      td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
      return rlax.l2_loss(td_error)

    dloss_dtheta = jax.grad(q_learning_loss)(net_params, obs_tm1, a_tm1, r_t,
                                             discount_t, q_t)
    updates, opt_state = optimizer.update(dloss_dtheta, opt_state)
    net_params = optix.apply_updates(net_params, updates)
    return net_params, opt_state

  print(f"Training agent for {FLAGS.train_episodes} episodes...")
  # (The episode loop that follows is not included in this excerpt.)