02 - Environment Implementation

The environment is where the economics becomes executable. In this project, that translation had to satisfy two constraints at once. It had to remain faithful to the strategic structure of Cournot and Bertrand competition, and it had to run fast enough to support large experimental grids across algorithms, information structures, firm counts, and random seeds. That second constraint is why the multi-agent environments were implemented around JAX.

The design is not “everything in JAX” and not “pure Python with some vectorization.” It is a hybrid architecture. The public environment remains a plain Python class with a simple RL interface: reset(), step(), benchmarks(), and counterfactual helpers used in evaluation. Inside that class, the market-clearing arithmetic is delegated to small JAX kernels that are compiled once and then reused. This separation keeps the code readable at the interface level while moving the heavy numerical work into compiled functions.

Why JAX Here

The computational bottleneck appears when all firms learn simultaneously. In that setting, every training step requires repeated market clearing, reward computation for all agents, and counterfactual profit evaluation for equilibrium diagnostics. The same arithmetic is executed thousands or millions of times.

JAX is useful in exactly that regime. The environment dynamics are deterministic array transformations: clip actions, aggregate quantities or prices, compute market outcomes, and return profits. These are the kinds of functions that benefit from JIT compilation because the control flow is simple and the shape structure is stable across steps. The point is not only raw speed. It is reducing the cost of repeated evaluation enough that the experimental design remains broad instead of collapsing into a handful of seeds or short runs.
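As an illustration, a Cournot step kernel of this kind fits in a few lines. The sketch assumes linear inverse demand $P = \max(a - bQ, 0)$ and a common constant marginal cost. The section does not fix these functional forms, and `cournot_step`, `A`, `B`, `C`, and `Q_MAX` are hypothetical names rather than the project's code.

```python
import jax
import jax.numpy as jnp

# Hypothetical market parameters. Assumed (not stated in the text):
# linear inverse demand P = max(A - B*Q, 0) and constant marginal cost C.
A, B, C, Q_MAX = 10.0, 1.0, 1.0, 10.0


@jax.jit
def cournot_step(q):
    """Clip quantities, clear the market, and return (price, profits)."""
    q = jnp.clip(q, 0.0, Q_MAX)               # project actions onto [0, Q_MAX]
    total = jnp.sum(q)                        # aggregate output Q
    price = jnp.maximum(A - B * total, 0.0)   # market-clearing price
    profits = (price - C) * q                 # profits for all firms in one pass
    return price, profits


price, profits = cournot_step(jnp.array([2.25, 2.25, 2.25]))
```

The first call traces and compiles the function for a length-3 input. Later calls with the same shape reuse the compiled version, which is the behavior the argument above relies on.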

This matters methodologically. In decentralized learning, the distinction between convergence, cycling, and accidental coordination often appears only after many replications. If simulation is expensive, the first thing that gets cut is replication. A JAX-based environment helps preserve statistical width in the experiment rather than forcing narrow evidence from a small computational budget.

Hybrid Architecture

The implemented architecture separates the environment into two layers.

At the outer layer, the environment is an ordinary Python object. It validates configuration, stores analytical benchmarks, exposes the RL API, and caches the previous market outcome so each agent can receive the correct observation. This keeps the interface easy to test and easy to extend. The trainer does not need to know anything about compilation details.

```mermaid
flowchart TD
    A[Trainer] --> B[Python Environment Class]
    B --> C[Config validation and benchmark loading]
    B --> D[reset and step public API]
    B --> E[Cached last-step state<br/>prices quantities profits observations]
    D --> F[JAX kernel call]
    F --> G[jit-compiled market clearing]
    F --> H[vmapped counterfactual profits]
    G --> E
    H --> E
    E --> I[Observation assembly by info structure]
    I --> A
```

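At the inner layer, the economically meaningful but numerically repetitive parts are compiled. In the Cournot environment, the JAX step kernel takes an action vector $\mathbf{q} = (q_1, \dots, q_N)$, clips each quantity to the feasible interval $[0, \bar{q}]$, and computes total output,

$$Q = \sum_{i=1}^{N} q_i .$$

It then computes the price from inverse demand,

$$P = P(Q),$$

and returns profits,

$$\pi_i = P(Q)\, q_i - C_i(q_i), \qquad i = 1, \dots, N.$$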
The Bertrand environment follows the same pattern. The kernel takes a price vector, evaluates logit demand shares, and returns profits for all firms in one compiled pass. In both environments, the class method step() converts Python inputs into JAX arrays, calls the compiled kernel, and then stores the resulting arrays as NumPy state for the observation interface.
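A Bertrand kernel with logit demand can be sketched the same way. The specification below is an assumption for illustration: a multinomial logit with a common quality index, an outside good, and constant marginal cost, so that $s_i = e^{(\alpha - p_i)/\mu} / \big(\sum_j e^{(\alpha - p_j)/\mu} + e^{a_0/\mu}\big)$. The names `bertrand_step`, `QUALITY`, `OUTSIDE`, `MU`, and `COST` are hypothetical, because the section does not state the exact demand parameters.

```python
import jax
import jax.numpy as jnp

# Hypothetical logit-demand parameters (not taken from the project):
# common quality index, outside-good utility, logit scale, marginal cost.
QUALITY, OUTSIDE, MU, COST = 2.0, 0.0, 0.25, 1.0


@jax.jit
def bertrand_step(p):
    """Return logit demand shares and profits for a price vector p."""
    inside = (QUALITY - p) / MU                       # utilities of the inside goods
    utilities = jnp.concatenate([inside, jnp.array([OUTSIDE / MU])])
    shares = jax.nn.softmax(utilities)[:-1]           # drop the outside good's share
    profits = (p - COST) * shares                     # all firms in one compiled pass
    return shares, profits
```

Putting the outside good into the softmax keeps the inside shares summing to less than one. The same call returns every firm's share and profit, matching the single compiled pass described above.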

This is a deliberate compromise. A fully JAX-native environment would push more of the trainer and state handling into a functional style. That can be elegant, but it also raises the cost of development and debugging. Here the project only compiles the parts that are repeated often enough to matter.

JIT Compilation and Function Caches

The key implementation choice is that compilation happens once per market specification, not once per training step. Each environment module maintains small module-level caches for compiled functions. In Cournot, the step kernel is keyed by the tuple of market parameters that defines the game. In Bertrand, the key is the corresponding market tuple for the logit demand system. If a later experiment reuses the same market configuration, the already-compiled function can be reused immediately.

That cache matters because compilation is not free. Without it, the code would pay setup cost repeatedly and destroy much of the performance gain. With it, the expensive part happens only at first use. After that, each period of the game becomes a compiled array computation rather than a fresh interpretation of Python arithmetic.

The same pattern appears in the counterfactual-profit routines used for CE-gap evaluation. The project does not compute these deviations with explicit Python loops over candidate actions. Instead, it compiles a function that returns one firm's profit for a given action and vectorizes it with vmap over an action grid. For a fixed rival profile, the environment can therefore evaluate many unilateral deviations in one JAX call. This matters because equilibrium diagnostics are not incidental in this project; they are part of the empirical argument.
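The deviation step can be sketched with `jax.vmap` over a grid of candidate actions. The firm index and the rival profile are held fixed with `in_axes=(0, None, None)`. The linear market, the names `firm_profit`, `deviation_profits`, and `best_deviation_gain`, and the "largest gain on the grid" summary are all assumptions. The project's exact CE-gap definition is not reproduced here.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear Cournot market: P = max(A - B*Q, 0), marginal cost C.
A, B, C = 10.0, 1.0, 1.0


def firm_profit(q_i, i, q):
    """Profit of firm i if it plays q_i while rivals keep their actions in q."""
    q_dev = q.at[i].set(q_i)                        # unilateral deviation
    price = jnp.maximum(A - B * jnp.sum(q_dev), 0.0)
    return (price - C) * q_i


# Vectorize over candidate actions only; the index and profile are broadcast.
deviation_profits = jax.jit(jax.vmap(firm_profit, in_axes=(0, None, None)))


def best_deviation_gain(i, q, grid):
    """Largest profit gain firm i could obtain by deviating to a grid action."""
    current = firm_profit(q[i], i, q)
    return jnp.max(deviation_profits(grid, i, q)) - current


grid = jnp.linspace(0.0, 5.0, 21)
gain = best_deviation_gain(0, jnp.array([1.0, 1.0, 1.0]), grid)
```

In this assumed market, the gain is zero up to rounding at the symmetric Nash profile $q_i = 2.25$. Away from it the gain is positive. For the profile above it is 6.25, reached by deviating to 3.5.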

Observations Without Recomputing the Market

One subtle implementation decision is that observations are assembled from cached last-step outcomes instead of rerunning economic calculations when _obs_for() is called. After each compiled step, the environment stores arrays such as profits, prices, quantities, and rivals’ actions. Observation construction then becomes a simple read from that cached state.

This avoids two problems. First, it prevents duplicated numerical work. Second, it guarantees that all agents’ observations come from the exact same realized market outcome. In a multi-agent setting, that consistency matters. If the observation layer were allowed to recompute quantities or prices independently, small numerical mismatches could create hard-to-diagnose bugs in learning dynamics.

The information structures are then implemented as slices of that cached outcome. In Cournot, an agent may observe only its own profit, the market price, or the rivals’ quantities. In Bertrand, it may observe only profit, its own last price, or the full price vector. The economic environment is the same in each case. What changes is the interface through which the agent learns from it.
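The outer layer might look like the following sketch. `step()` calls a compiled kernel once and stores the realized outcome as NumPy arrays, and `_obs_for()` only slices that cache. The section names `reset`, `step`, `benchmarks`, and `_obs_for`. Everything else is hypothetical: the class name, the info-structure labels, the linear market, and its textbook Cournot-Nash and collusive benchmark formulas.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical linear Cournot market used only for this sketch.
A, B, C, Q_MAX = 10.0, 1.0, 1.0, 10.0


@jax.jit
def _kernel(q):
    q = jnp.clip(q, 0.0, Q_MAX)
    price = jnp.maximum(A - B * jnp.sum(q), 0.0)
    return q, price, (price - C) * q


class CournotEnvSketch:
    """Thin Python wrapper: compiled kernel inside, cached NumPy state outside."""

    INFO_STRUCTURES = ("profit", "price", "rivals")  # hypothetical labels

    def __init__(self, n_firms, info="profit"):
        if n_firms < 2:
            raise ValueError("need at least two firms")
        if info not in self.INFO_STRUCTURES:
            raise ValueError(f"unknown info structure: {info}")
        self.n, self.info = n_firms, info
        self.reset()

    def benchmarks(self):
        """Per-firm Nash and collusive quantities for the assumed linear market."""
        return {"nash_quantity": (A - C) / (B * (self.n + 1)),
                "collusive_quantity": (A - C) / (2 * B * self.n)}

    def reset(self):
        self.last_q = np.zeros(self.n)
        self.last_price = 0.0
        self.last_profits = np.zeros(self.n)
        return [self._obs_for(i) for i in range(self.n)]

    def step(self, actions):
        q, price, profits = _kernel(jnp.asarray(actions, dtype=jnp.float32))
        # Cache one realized outcome; every agent's observation reads from it.
        self.last_q = np.asarray(q)
        self.last_price = float(price)
        self.last_profits = np.asarray(profits)
        obs = [self._obs_for(i) for i in range(self.n)]
        return obs, self.last_profits.copy()

    def _obs_for(self, i):
        """Slice the cached outcome according to the information structure."""
        if self.info == "profit":
            return np.array([self.last_profits[i]])
        if self.info == "price":
            return np.array([self.last_price])
        return np.delete(self.last_q, i)  # rivals' quantities
```

Because `_obs_for()` never recomputes prices or quantities, all agents see slices of the same realized market outcome whatever the information structure.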

Why the Design Is Selective

The JAX implementation is concentrated where numerical repetition is highest, not forced uniformly onto every part of the codebase. That asymmetry is intentional.

Some environment components are lightweight and stateful in ways that are easier to express with ordinary Python structures. History buffers, interface validation, benchmark storage, and observation assembly all remain simpler and more transparent outside a fully functional JAX style. By contrast, market clearing, profit computation, and counterfactual evaluation are repeated often enough that compilation pays for itself quickly.

The project therefore uses JAX where throughput matters and keeps plain Python where it improves clarity without creating meaningful cost. That balance is part of the implementation strategy, not a temporary compromise.

Practical Advantages of the JAX Design

The first advantage is throughput. Market clearing and reward computation happen for all firms in a single compiled kernel call. That reduces per-step overhead and makes longer runs and larger sweeps feasible.

The second advantage is vectorized counterfactual analysis. Metrics such as CE-gap require repeated evaluation of unilateral deviations. Because those deviations are vmapped, the project can treat diagnostic computation as part of the standard workflow rather than an expensive afterthought.

The third advantage is consistency across environments. Cournot and Bertrand use different economics, but they share the same design logic: a thin Python environment API, compiled numerical kernels, cached previous-step outcomes, and explicit benchmark access. That uniformity lowers the cognitive cost of extending the project to new market forms.

The fourth advantage is reproducibility under a heavy experimental load. Faster evaluation does not only save time. It reduces the pressure to simplify the experiment in ways that would weaken inference. If broad sweeps over seeds and information structures remain affordable, the claims drawn from the results are better grounded.

There is also a less visible benefit: the implementation forces a clean separation between model primitives and training logic. Demand systems, cost rules, and benchmark calculations remain explicit. The trainer asks the environment for outcomes; it does not reimplement the economics. That separation is good software engineering, but it is also good research practice because it keeps theoretical assumptions localized and inspectable.

Limits and Tradeoffs

This design is not maximalist. The environment object itself is not JIT-compiled end to end. The trainer still loops in Python over steps and over agents. State is transferred back into NumPy after each compiled call so the rest of the codebase can remain simple. That means the implementation does not extract every possible speed gain from JAX.

But the tradeoff is defensible. The objective here is not to build a benchmark-optimized simulator at any readability cost. It is to support a research workflow where correctness, transparency, and extensibility matter alongside speed. Compiling the inner kernels captures most of the performance gain while avoiding a full rewrite of the training architecture in a purely functional style.

That is why the current implementation is best understood as a research-oriented JAX design rather than a fully JAX-native system. It uses compilation where the economics repeats, keeps the public API conventional, and leaves room for later extensions if larger-scale experiments demand a more aggressive functional rewrite.
