<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
  <title>Tony Wang&#39;s blogs</title>
  <icon>https://www.gravatar.com/avatar/7bb8834d839588bd4d8609b964822081</icon>
  <subtitle>Technology Change the World</subtitle>
  <link href="/atom.xml" rel="self"/>
  
  <link href="https://blog.aivgg.com/"/>
  <updated>2026-06-10T16:24:04.383Z</updated>
  <id>https://blog.aivgg.com/</id>
  
  <author>
    <name>Tony Wang</name>
    <email>tony.pfwang@gmail.com</email>
  </author>
  
  <generator uri="https://hexo.io/">Hexo</generator>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/4113725857.html"/>
    <id>https://blog.aivgg.com/posts/4113725857.html</id>
    <published>2026-06-10T16:18:37.542Z</published>
    <updated>2026-06-10T16:24:04.383Z</updated>
    
    <content type="html"><![CDATA[<h3 id="基础"><a href="#基础" class="headerlink" title="基础"></a>Basics</h3><p>With the following code, you can watch a cart-pole (an inverted pendulum) being driven by random actions:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> gym  <span class="comment"># classic gym API (before 0.26): reset() returns obs, step() returns 4 values</span></span><br><span class="line"></span><br><span class="line">env_name = <span class="string">&quot;CartPole-v0&quot;</span></span><br><span class="line">env = gym.make(env_name)          <span class="comment"># create the environment</span></span><br><span class="line"></span><br><span class="line">episodes = <span class="number">10</span></span><br><span class="line"><span class="keyword">for</span> episode <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">1</span>, episodes + <span class="number">1</span>):</span><br><span class="line">    state = env.reset()           </span><br><span class="line">    done = <span class="literal">False</span></span><br><span class="line">    score = <span class="number">0</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">while</span> <span class="keyword">not</span> done:</span><br><span class="line">        env.render()                           <span class="comment"># render the environment</span></span><br><span class="line">        action = env.action_space.sample()     <span class="comment"># sample a random action</span></span><br><span class="line">        n_state, reward, done, info = env.step(action)    <span class="comment"># interact with the environment: get the next state, reward, etc.</span></span><br><span class="line">        score += reward                        <span class="comment"># accumulate the score</span></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;Episode : &#123;&#125;, Score : &#123;&#125;&quot;</span>.<span class="built_in">format</span>(episode, score))</span><br><span class="line"></span><br><span class="line">env.close()     <span class="comment"># close the render window</span></span><br></pre></td></tr></table></figure>
<p>A reinforcement learning model trained with Stable-Baselines3 can control this environment well:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> stable_baselines3 <span class="keyword">import</span> DQN</span><br><span class="line"><span class="keyword">from</span> stable_baselines3.common.vec_env <span class="keyword">import</span> DummyVecEnv</span><br><span class="line"><span class="keyword">from</span> stable_baselines3.common.evaluation <span class="keyword">import</span> evaluate_policy</span><br><span class="line"><span class="keyword">import</span> gym</span><br><span class="line"></span><br><span class="line">env_name = <span class="string">&quot;CartPole-v0&quot;</span></span><br><span class="line">env = gym.make(env_name)</span><br><span class="line"><span class="comment"># Vectorize the environment: several environments can be passed to DummyVecEnv as a list and are then run in a single thread, which improves training efficiency</span></span><br><span class="line">env = DummyVecEnv([<span class="keyword">lambda</span> : env])</span><br><span class="line"><span class="comment"># Define a DQN model and set its parameters</span></span><br><span class="line">model = DQN(</span><br><span class="line">    <span class="string">&quot;MlpPolicy&quot;</span>,                                <span class="comment"># MlpPolicy: use an MLP as the policy (Q-)network</span></span><br><span class="line">    env=env, </span><br><span class="line">    learning_rate=<span class="number">5e-4</span>,</span><br><span class="line">    batch_size=<span class="number">128</span>,</span><br><span class="line">    buffer_size=<span class="number">50000</span>,</span><br><span class="line">    learning_starts=<span class="number">0</span>,</span><br><span class="line">    target_update_interval=<span class="number">250</span>,</span><br><span class="line">    policy_kwargs=&#123;<span class="string">&quot;net_arch&quot;</span> : [<span class="number">256</span>, <span class="number">256</span>]&#125;,     <span class="comment"># two hidden layers of 256 units each</span></span><br><span class="line">    verbose=<span class="number">1</span>,                                   <span class="comment"># 0: print nothing, 1: print training info, 2: print debug info</span></span><br><span class="line">    tensorboard_log=<span class="string">&quot;./tensorboard/CartPole-v0/&quot;</span>  <span class="comment"># directory for training logs, viewable with TensorBoard</span></span><br><span class="line">)</span><br><span
class="line"><span class="comment"># Start training</span></span><br><span class="line">model.learn(total_timesteps=<span class="built_in">int</span>(<span class="number">1e5</span>))</span><br><span class="line"><span class="comment"># Evaluate the policy: the pole now stays balanced</span></span><br><span class="line">mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=<span class="number">10</span>, render=<span class="literal">True</span>)</span><br><span class="line"><span class="comment">#env.close()</span></span><br><span class="line"><span class="built_in">print</span>(<span class="string">&quot;mean_reward:&quot;</span>,mean_reward,<span class="string">&quot;std_reward:&quot;</span>,std_reward)</span><br><span class="line"><span class="comment"># Save the model (Stable-Baselines3 writes it as a zip archive)</span></span><br><span class="line">model.save(<span class="string">&quot;./model/CartPole.pkl&quot;</span>)</span><br></pre></td></tr></table></figure>
<h3 id="自定义环境"><a href="#自定义环境" class="headerlink" title="自定义环境"></a>Custom Environments</h3><p>Subclass gym.Env, override its methods and set the required attributes (such as action_space and observation_space). The template looks like this:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> gym</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">from</span> gym <span class="keyword">import</span> spaces</span><br><span class="line"></span><br><span class="line"><span class="comment"># Placeholder sizes: replace them with the values for your own task</span></span><br><span class="line">N_DISCRETE_ACTIONS = <span class="number">2</span></span><br><span class="line">N_CHANNELS, HEIGHT, WIDTH = <span class="number">3</span>, <span class="number">64</span>, <span class="number">64</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">CustomEnv</span>(gym.Env):</span><br><span class="line">    <span class="string">&quot;&quot;&quot;Custom Environment that follows gym interface&quot;&quot;&quot;</span></span><br><span class="line">    metadata = &#123;<span class="string">&#x27;render.modes&#x27;</span>: [<span class="string">&#x27;human&#x27;</span>]&#125;</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, arg1=<span class="literal">None</span>, arg2=<span class="literal">None</span></span>):  <span class="comment"># arg1, arg2: your own constructor arguments</span></span><br><span class="line">        <span class="built_in">super</span>(CustomEnv, <span class="variable language_">self</span>).__init__()</span><br><span class="line">        <span class="comment"># Define action and observation space</span></span><br><span class="line">        <span class="comment"># They must be gym.spaces objects</span></span><br><span class="line">        <span class="comment"># Example when using discrete actions:</span></span><br><span class="line">        <span class="variable language_">self</span>.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)</span><br><span class="line">        <span class="comment"># Example for using image as input (channel-first; channel-last also works):</span></span><br><span class="line">        <span class="variable language_">self</span>.observation_space = spaces.Box(low=<span class="number">0</span>, high=<span class="number">255</span>,</span><br><span class="line">                                            shape=(N_CHANNELS, HEIGHT, WIDTH), dtype=np.uint8)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">step</span>(<span class="params">self, action</span>):</span><br><span class="line">        <span class="comment"># Placeholder: replace with your own dynamics and reward</span></span><br><span class="line">        observation = <span class="variable language_">self</span>.observation_space.sample()</span><br><span class="line">        reward, done, info = <span class="number">0.0</span>, <span class="literal">False</span>, &#123;&#125;</span><br><span class="line">        <span class="keyword">return</span> observation, reward, done, info</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">reset</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="comment"># Placeholder: replace with your own initial observation</span></span><br><span class="line">        observation = <span class="variable language_">self</span>.observation_space.sample()</span><br><span class="line">        <span class="keyword">return</span> observation  <span class="comment"># reward, done, info can&#x27;t be included</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">render</span>(<span class="params">self, mode=<span class="string">&#x27;human&#x27;</span></span>):</span><br><span class="line">        <span class="keyword">pass</span>  <span class="comment"># optional: draw the current state</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">close</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="keyword">pass</span></span><br></pre></td></tr></table></figure>
<p>The three main methods are:</p><blockquote><p>reset(): runs at the start of every episode and returns the current observation.</p><p>step(action): takes an action, which the agent executes to interact with the environment, and returns (new observation, reward, done, info).</p><p>render(mode=&#x27;human&#x27;) (optional): renders the environment.</p></blockquote><ul><li>gym.spaces.Box: a continuous space of arbitrary shape</li><li>spaces.Discrete: a single integer that takes one of n enumerated values (0 to n-1)</li></ul><p>Check whether the environment conforms to the gym interface:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> stable_baselines3.common.env_checker <span class="keyword">import</span> check_env</span><br><span class="line"></span><br><span class="line">env = CustomEnv(arg1, ...)</span><br><span class="line"><span class="comment"># It will check your custom environment and output additional warnings if needed</span></span><br><span
class="line">check_env(env)</span><br></pre></td></tr></table></figure><p>Create a 1D environment in which the agent learns to always move left. The observation is the agent's current position, and the agent has two actions, left and right, represented by 0 and 1 respectively.</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span
class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br></pre></td><td class="code"><pre><span class="line"></span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> gym</span><br><span class="line"><span class="keyword">from</span> gym <span class="keyword">import</span> spaces</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">GoLeftEnv</span>(gym.Env):</span><br><span class="line">  <span class="string">&quot;&quot;&quot;</span></span><br><span class="line"><span class="string">  A 1D grid environment in which the agent learns to always go left</span></span><br><span class="line"><span class="string">  &quot;&quot;&quot;</span></span><br><span class="line">  metadata = &#123;<span class="string">&#x27;render.modes&#x27;</span>: [<span class="string">&#x27;console&#x27;</span>]&#125;</span><br><span class="line">  LEFT = <span class="number">0</span></span><br><span class="line">  RIGHT = <span class="number">1</span></span><br><span class="line"></span><br><span class="line">  <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, grid_size=<span class="number">10</span></span>):</span><br><span class="line">    <span class="built_in">super</span>(GoLeftEnv, <span class="variable language_">self</span>).__init__()</span><br><span class="line"></span><br><span class="line">    <span class="comment"># Size of the 1D grid</span></span><br><span class="line">    <span class="variable language_">self</span>.grid_size = grid_size</span><br><span class="line">    <span class="comment"># The agent starts at the rightmost cell of the grid</span></span><br><span class="line">    <span class="variable language_">self</span>.agent_pos = grid_size - <span class="number">1</span></span><br><span class="line"></span><br><span class="line">    <span class="comment"># Define the action and observation spaces</span></span><br><span class="line">    <span class="comment"># Discrete action space: left, right</span></span><br><span class="line">    n_actions = <span class="number">2</span></span><br><span class="line">    <span class="variable language_">self</span>.action_space = spaces.Discrete(n_actions)</span><br><span class="line">    <span class="comment"># The observation is the agent&#x27;s current position</span></span><br><span class="line">    <span class="variable language_">self</span>.observation_space = spaces.Box(low=<span class="number">0</span>, high=<span class="variable language_">self</span>.grid_size,</span><br><span class="line">                                        shape=(<span class="number">1</span>,), dtype=np.float32)</span><br><span class="line"></span><br><span class="line">  <span class="keyword">def</span> <span class="title function_">reset</span>(<span class="params">self</span>):</span><br><span class="line">    <span class="string">&quot;&quot;&quot;</span></span><br><span class="line"><span class="string">    Important: the observation must be a np.array</span></span><br><span class="line"><span class="string">    :return: (np.array) </span></span><br><span class="line"><span class="string">    &quot;&quot;&quot;</span></span><br><span class="line">    <span class="comment"># Initialize the agent at the right of the grid</span></span><br><span class="line">    <span class="variable language_">self</span>.agent_pos = <span class="variable language_">self</span>.grid_size - <span class="number">1</span></span><br><span class="line">    <span class="comment"># here we convert to float32 to make it more general (in case we want to use continuous actions)</span></span><br><span
class="line">    <span class="keyword">return</span> np.array([<span class="variable language_">self</span>.agent_pos]).astype(np.float32)</span><br><span class="line"></span><br><span class="line">  <span class="keyword">def</span> <span class="title function_">step</span>(<span class="params">self, action</span>):</span><br><span class="line">    <span class="keyword">if</span> action == <span class="variable language_">self</span>.LEFT:</span><br><span class="line">      <span class="variable language_">self</span>.agent_pos -= <span class="number">1</span></span><br><span class="line">    <span class="keyword">elif</span> action == <span class="variable language_">self</span>.RIGHT:</span><br><span class="line">      <span class="variable language_">self</span>.agent_pos += <span class="number">1</span></span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">      <span class="keyword">raise</span> ValueError(<span class="string">&quot;Received invalid action=&#123;&#125; which is not part of the action space&quot;</span>.<span class="built_in">format</span>(action))</span><br><span class="line">    <span class="comment"># The agent cannot move past either edge of the grid (cells 0 to grid_size - 1)</span></span><br><span class="line">    <span class="variable language_">self</span>.agent_pos = np.clip(<span class="variable language_">self</span>.agent_pos, <span class="number">0</span>, <span class="variable language_">self</span>.grid_size - <span class="number">1</span>)</span><br><span class="line">    <span class="comment"># The episode ends when the agent reaches the leftmost cell</span></span><br><span class="line">    done = <span class="built_in">bool</span>(<span class="variable language_">self</span>.agent_pos == <span class="number">0</span>)</span><br><span class="line">    <span class="comment"># Give a positive reward when the agent reaches the leftmost cell</span></span><br><span class="line">    reward = <span class="number">1</span> <span class="keyword">if</span> <span class="variable language_">self</span>.agent_pos == <span class="number">0</span> <span class="keyword">else</span> <span class="number">0</span></span><br><span class="line">    <span class="comment"># No extra information to return for now</span></span><br><span class="line">    info = &#123;&#125;</span><br><span class="line">    <span class="keyword">return</span> np.array([<span class="variable language_">self</span>.agent_pos]).astype(np.float32), reward, done, info</span><br><span class="line"></span><br><span class="line">  <span class="keyword">def</span> <span class="title function_">render</span>(<span class="params">self, mode=<span class="string">&#x27;console&#x27;</span></span>):</span><br><span class="line">    <span class="comment"># Render in the console</span></span><br><span class="line">    <span class="keyword">if</span> mode != <span class="string">&#x27;console&#x27;</span>:</span><br><span class="line">      <span class="keyword">raise</span> NotImplementedError()</span><br><span class="line">    <span class="comment"># agent is represented as a cross, rest as a dot</span></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;.&quot;</span> * <span class="variable language_">self</span>.agent_pos, end=<span class="string">&quot;&quot;</span>)</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;x&quot;</span>, end=<span class="string">&quot;&quot;</span>)</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;.&quot;</span> * (<span class="variable language_">self</span>.grid_size - <span class="variable language_">self</span>.agent_pos - <span class="number">1</span>))</span><br><span class="line"></span><br><span class="line">  <span class="keyword">def</span> <span class="title function_">close</span>(<span class="params">self</span>):</span><br><span class="line">    <span class="keyword">pass</span></span><br></pre></td></tr></table></figure>
<p>Build the environment and train the agent:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span
class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> stable_baselines3 <span class="keyword">import</span> PPO, A2C</span><br><span class="line"><span class="keyword">from</span> stable_baselines3.common.env_util <span class="keyword">import</span> make_vec_env</span><br><span class="line"></span><br><span class="line"><span class="comment"># Build the environment</span></span><br><span class="line">env = GoLeftEnv(grid_size=<span class="number">10</span>)</span><br><span class="line">env = make_vec_env(<span class="keyword">lambda</span>: env, n_envs=<span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Train the agent</span></span><br><span class="line">model = A2C(<span class="string">&#x27;MlpPolicy&#x27;</span>, env, verbose=<span class="number">1</span>).learn(<span class="number">5000</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Test the trained agent</span></span><br><span class="line">obs = env.reset()</span><br><span class="line">n_steps = <span class="number">20</span></span><br><span
class="line"><span class="keyword">for</span> step <span class="keyword">in</span> <span class="built_in">range</span>(n_steps):</span><br><span class="line">  action, _ = model.predict(obs, deterministic=<span class="literal">True</span>)</span><br><span class="line">  <span class="built_in">print</span>(<span class="string">&quot;Step &#123;&#125;&quot;</span>.<span class="built_in">format</span>(step + <span class="number">1</span>))</span><br><span class="line">  <span class="built_in">print</span>(<span class="string">&quot;Action: &quot;</span>, action)</span><br><span class="line">  obs, reward, done, info = env.step(action)</span><br><span class="line">  <span class="built_in">print</span>(<span class="string">&#x27;obs=&#x27;</span>, obs, <span class="string">&#x27;reward=&#x27;</span>, reward, <span class="string">&#x27;done=&#x27;</span>, done)</span><br><span class="line">  env.render(mode=<span class="string">&#x27;console&#x27;</span>)</span><br><span class="line">  <span class="keyword">if</span> done:</span><br><span class="line">    <span class="comment"># Note that the VecEnv resets automatically</span></span><br><span class="line">    <span class="comment"># when a done signal is encountered</span></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;Goal reached!&quot;</span>, <span class="string">&quot;reward=&quot;</span>, reward)</span><br><span class="line">    <span class="keyword">break</span></span><br></pre></td></tr></table></figure>
<p>The action space can also be continuous, as in the following skeleton of a robot environment:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span
class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># -*- coding: utf-8 -*-</span></span><br><span class="line"><span class="keyword">import</span> sys</span><br><span class="line"><span class="keyword">import</span> gym</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line">gym.logger.set_level(<span class="number">40</span>)  <span class="comment"># 40 = ERROR level: hide gym warnings</span></span><br><span class="line"><span class="comment"># sys.path.append(&#x27;absolute path of the parent directory, e.g. ~/autodl-nas/robot/&#x27;)</span></span><br><span class="line"><span class="comment"># import Params</span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">RobotEnv</span>(gym.Env):</span><br><span class="line">    <span class="comment"># Initialize the parameters</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="comment"># State space: 18 (joint angles) + 18 (joint angular velocities) + 2 (control direction) + 1 (control speed) = 39</span></span><br><span class="line">        <span class="variable language_">self</span>.observation_space = gym.spaces.Box(low=-<span class="number">1</span>, high=<span class="number">1</span>, shape=(<span class="number">39</span>,), dtype=np.float32)</span><br><span class="line">        <span class="comment"># Action space: 18 (joint angles)</span></span><br><span class="line">        <span class="variable language_">self</span>.action_space = gym.spaces.Box(low=-<span class="number">1</span>, high=<span class="number">1</span>, shape=(<span class="number">18</span>,), dtype=np.float32)</span><br><span class="line">        <span class="variable language_">self</span>.done = <span class="literal">False</span></span><br><span class="line">        <span class="comment"># Optional extra feature, used in get_reward()</span></span><br><span class="line">        <span class="comment"># self.reward_fun = Params.PARAMS_ENV[&#x27;reward_fun&#x27;]</span></span><br><span class="line"></span><br><span
class="line">    <span class="comment"># Read the raw data, build the state from it and update done</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">get_state</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="comment"># TODO: replace None with your own function that reads the raw data</span></span><br><span class="line">        origin_data = <span class="literal">None</span></span><br><span class="line">        <span class="comment"># TODO: assemble origin_data into the state; normalize joint angles, angular velocities</span></span><br><span class="line">        <span class="comment"># and direction to [-1, 1], and divide the speed by a coefficient to bring it into [0, 1]</span></span><br><span class="line">        tmp_data = np.zeros(<span class="number">39</span>)  <span class="comment"># placeholder</span></span><br><span class="line">        state = np.array(tmp_data, dtype=np.float32).reshape(<span class="number">39</span>)  <span class="comment"># must match observation_space.shape</span></span><br><span class="line">        <span class="comment"># TODO: replace False with your termination condition</span></span><br><span class="line">        <span class="variable language_">self</span>.done = <span class="literal">False</span></span><br><span class="line">        <span class="keyword">return</span> state</span><br><span class="line"></span><br><span class="line">    <span class="comment"># Send the control command for the RL output to move the robot</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">move</span>(<span class="params">self, action</span>):</span><br><span class="line">        <span class="comment"># TODO: send the command corresponding to action to the robot in the simulator</span></span><br><span class="line">        <span class="keyword">pass</span></span><br><span class="line"></span><br><span
class="line">    <span class="comment"># Reward function</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">get_reward</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="comment"># TODO: define the reward for your task. Suggested terms: 1) survival time (can the robot stand and move),</span></span><br><span class="line">        <span class="comment"># 2) whether it moves in the given direction, 3) whether it moves at the given speed</span></span><br><span class="line">        reward = <span class="number">0.0</span>  <span class="comment"># placeholder</span></span><br><span class="line">        <span class="comment"># reward = self.reward_fun  # if used, the line above is not needed and can be defined in the parameter file instead</span></span><br><span class="line">        <span class="keyword">return</span> reward</span><br><span class="line"></span><br><span class="line">    <span class="comment"># Main function: one environment step</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">step</span>(<span class="params">self, action</span>):</span><br><span class="line">        <span class="comment"># Execute the action</span></span><br><span class="line">        <span class="variable language_">self</span>.move(action)</span><br><span class="line">        <span class="comment"># Get the reward for the action</span></span><br><span class="line">        reward = <span class="variable language_">self</span>.get_reward()</span><br><span class="line">        <span class="comment"># Get the next state</span></span><br><span class="line">        state = <span class="variable language_">self</span>.get_state()</span><br><span class="line">        <span class="comment"># Return (state, reward, done, info)</span></span><br><span class="line">        <span class="keyword">return</span> state, reward, <span class="variable language_">self</span>.done, &#123;&#125;</span><br><span class="line"></span><br><span
class="line">    <span class="comment"># Reset the environment</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">reset</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="comment"># TODO: initialize the robot</span></span><br><span class="line">        <span class="comment"># Get the corresponding state</span></span><br><span class="line">        state = <span class="variable language_">self</span>.get_state()</span><br><span class="line">        <span class="keyword">return</span> state</span><br><span class="line"></span><br><span class="line">    <span class="comment"># Shut down the robot</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">close</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="comment"># TODO: shut down the robot</span></span><br><span class="line">        <span class="keyword">pass</span></span><br></pre></td></tr></table></figure>]]></content>
    
    <summary type="html">
    
      &lt;h3 id=&quot;基础&quot;&gt;&lt;a href=&quot;#基础&quot; class=&quot;headerlink&quot; title=&quot;基础&quot;&gt;&lt;/a&gt;Basics&lt;/h3&gt;&lt;p&gt;Run this code and you will see the inverted pendulum (CartPole) flailing under random actions:&lt;/p&gt;
&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;8&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;9&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;10&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;11&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;12&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;13&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;14&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;15&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;16&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;17&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;18&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;19&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; gym&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;env_name = &lt;span class=&quot;string&quot;&gt;&amp;quot;CartPole-v0&amp;quot;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;env = gym.make(env_name)          &lt;span class=&quot;comment&quot;&gt;# create the environment&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;episodes = &lt;span class=&quot;number&quot;&gt;10&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; episode &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;range&lt;/span&gt;(&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, episodes + &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    state = env.reset()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    done = &lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    score = &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;while&lt;/span&gt; &lt;span class=&quot;keyword&quot;&gt;not&lt;/span&gt; done:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        env.render()                           &lt;span class=&quot;comment&quot;&gt;# render the environment&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        action = env.action_space.sample()     &lt;span class=&quot;comment&quot;&gt;# sample a random action&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        n_state, reward, done, info = env.step(action)    &lt;span class=&quot;comment&quot;&gt;# interact with the env: next state, reward, done flag, info&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        score += reward                        &lt;span class=&quot;comment&quot;&gt;# accumulate the score&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;built_in&quot;&gt;print&lt;/span&gt;(&lt;span class=&quot;string&quot;&gt;&amp;quot;Episode : &amp;#123;&amp;#125;, Score : &amp;#123;&amp;#125;&amp;quot;&lt;/span&gt;.&lt;span class=&quot;built_in&quot;&gt;format&lt;/span&gt;(episode, score))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;env.close()     &lt;span class=&quot;comment&quot;&gt;# close the render window&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;
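&lt;p&gt;The loop above uses the classic gym API (gym before 0.26). In that API, reset() returns only the observation and step() returns four values. In gym 0.26+ and gymnasium, the API changed:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;reset() returns (observation, info).&lt;/li&gt;
&lt;li&gt;step() returns (observation, reward, terminated, truncated, info).&lt;/li&gt;
&lt;li&gt;Rendering is requested with gym.make(env_name, render_mode="human").&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;The two helpers below are a small sketch that lets the same loop run on either API. The names unpack_reset and unpack_step are hypothetical. The reset() check assumes the new API always returns a 2-tuple whose second item is a dict.&lt;/p&gt;

```python
def unpack_reset(result):
    """Return just the observation from env.reset().

    Assumes gym >= 0.26 / gymnasium return (obs, info) with info a dict,
    while older gym returns the observation alone.
    """
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def unpack_step(result):
    """Normalize env.step() output to the classic (obs, reward, done, info)."""
    if len(result) == 5:
        # gym >= 0.26 / gymnasium: (obs, reward, terminated, truncated, info)
        obs, reward, terminated, truncated, info = result
        return obs, reward, bool(terminated or truncated), info
    if len(result) == 4:
        # gym before 0.26: already (obs, reward, done, info)
        obs, reward, done, info = result
        return obs, reward, done, info
    raise ValueError(f"unexpected step() result of length {len(result)}")
```

&lt;p&gt;Inside the loop, write state = unpack_reset(env.reset()) and n_state, reward, done, info = unpack_step(env.step(action)). Folding truncated into done matches what the old API reported when CartPole hit its time limit.&lt;/p&gt;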

&lt;p&gt;A reinforcement learning model trained with Stable-Baselines3 (here DQN) can control this environment well:&lt;/p&gt;
&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;8&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;9&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;10&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;11&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;12&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;13&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;14&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;15&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;16&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;17&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;18&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;19&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;20&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;21&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;22&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;23&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;24&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;25&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;26&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;27&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;28&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;29&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;30&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span 
class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; stable_baselines3 &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; DQN&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; stable_baselines3.common.vec_env.dummy_vec_env &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; DummyVecEnv&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; stable_baselines3.common.evaluation &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; evaluate_policy&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; gym&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;env_name = &lt;span class=&quot;string&quot;&gt;&amp;quot;CartPole-v0&amp;quot;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;env = gym.make(env_name)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;comment&quot;&gt;# Vectorize the environment: DummyVecEnv takes a list of env factories and steps them one after another in this process (SubprocVecEnv runs them in parallel processes)&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;env = DummyVecEnv([&lt;span class=&quot;keyword&quot;&gt;lambda&lt;/span&gt;: env])&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;comment&quot;&gt;# Define a DQN model and set its hyperparameters&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;model = DQN(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;string&quot;&gt;&amp;quot;MlpPolicy&amp;quot;&lt;/span&gt;,                                &lt;span class=&quot;comment&quot;&gt;# MlpPolicy: the Q-network is a multilayer perceptron&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    env=env,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    learning_rate=&lt;span 
class=&quot;number&quot;&gt;5e-4&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    batch_size=&lt;span class=&quot;number&quot;&gt;128&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    buffer_size=&lt;span class=&quot;number&quot;&gt;50000&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    learning_starts=&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    target_update_interval=&lt;span class=&quot;number&quot;&gt;250&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    policy_kwargs=&amp;#123;&lt;span class=&quot;string&quot;&gt;&amp;quot;net_arch&amp;quot;&lt;/span&gt;: [&lt;span class=&quot;number&quot;&gt;256&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;256&lt;/span&gt;]&amp;#125;,      &lt;span class=&quot;comment&quot;&gt;# two hidden layers with 256 units each&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    verbose=&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;,                                   &lt;span class=&quot;comment&quot;&gt;# 0: no output, 1: training info, 2: debug info&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    tensorboard_log=&lt;span class=&quot;string&quot;&gt;&amp;quot;./tensorboard/CartPole-v0/&amp;quot;&lt;/span&gt;  &lt;span class=&quot;comment&quot;&gt;# log directory, viewable with TensorBoard&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;comment&quot;&gt;# Train&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;model.learn(total_timesteps=&lt;span class=&quot;number&quot;&gt;100_000&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;comment&quot;&gt;# Evaluate the policy; the rendered pendulum now stays balanced&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;mean_reward, std_reward = evaluate_policy(model, env, 
n_eval_episodes=&lt;span class=&quot;number&quot;&gt;10&lt;/span&gt;, render=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;comment&quot;&gt;#env.close()&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;built_in&quot;&gt;print&lt;/span&gt;(&lt;span class=&quot;string&quot;&gt;&amp;quot;mean_reward:&amp;quot;&lt;/span&gt;, mean_reward, &lt;span class=&quot;string&quot;&gt;&amp;quot;std_reward:&amp;quot;&lt;/span&gt;, std_reward)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;comment&quot;&gt;# Save the model to the given path&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;model.save(&lt;span class=&quot;string&quot;&gt;&amp;quot;./model/CartPole.pkl&amp;quot;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;
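&lt;p&gt;Under the hood, DQN regresses the Q-value of the action taken toward a bootstrapped target computed with a separate target network. SB3 copies the online network's weights into the target network every target_update_interval steps (250 above). The NumPy sketch below computes that target for one mini-batch. It is illustrative only, not SB3's code, and the function names are hypothetical. It also uses a squared error where SB3 uses a Huber (smooth L1) loss.&lt;/p&gt;

```python
import numpy as np


def dqn_targets(rewards, dones, next_q_target, gamma=0.99):
    """One-step DQN targets: y = r + gamma * (1 - done) * max_a' Q_target(s', a').

    rewards:       shape (batch,)
    dones:         shape (batch,), 1.0 where the episode ended at s'
    next_q_target: shape (batch, n_actions), target-network Q-values of s'
    """
    max_next_q = next_q_target.max(axis=1)  # greedy value of the next state
    return rewards + gamma * (1.0 - dones) * max_next_q


def td_loss(q_values, actions, targets):
    """Mean squared TD error on the Q-values of the actions actually taken."""
    chosen_q = q_values[np.arange(len(actions)), actions]
    return float(np.mean((chosen_q - targets) ** 2))


# Toy mini-batch of two transitions (values chosen by hand)
rewards = np.array([1.0, 1.0])
dones = np.array([0.0, 1.0])
next_q = np.array([[0.5, 2.0], [3.0, 1.0]])
targets = dqn_targets(rewards, dones, next_q, gamma=0.5)
```

&lt;p&gt;Here targets comes out as [2.0, 1.0]. The second transition ended its episode, so its target is just its reward and no future value is bootstrapped.&lt;/p&gt;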

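&lt;p&gt;DummyVecEnv, used above, runs its environments one after another in the current process. It also resets an environment automatically when its episode ends. SubprocVecEnv instead runs each environment in its own process. The sketch below imitates the sequential behaviour with a toy environment; SequentialVecEnv and CountdownEnv are hypothetical names, not SB3 classes. Like SB3, it stores the last observation of a finished episode under info["terminal_observation"].&lt;/p&gt;

```python
import numpy as np


class CountdownEnv:
    """Hypothetical toy env: the episode ends after `length` steps."""

    def __init__(self, length):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return np.array([self.length], dtype=np.float32)

    def step(self, action):
        self.t += 1
        obs = np.array([self.length - self.t], dtype=np.float32)
        return obs, 1.0, self.t == self.length, {}


class SequentialVecEnv:
    """Steps several envs one after another in this process and resets any
    env whose episode ended, so the caller always receives a valid obs."""

    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]

    def reset(self):
        return np.stack([env.reset() for env in self.envs])

    def step(self, actions):
        obs_list, rewards, dones, infos = [], [], [], []
        for env, action in zip(self.envs, actions):
            obs, reward, done, info = env.step(action)
            if done:
                info = dict(info, terminal_observation=obs)
                obs = env.reset()  # auto-reset, as a vectorized env does
            obs_list.append(obs)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)
        return np.stack(obs_list), np.array(rewards), np.array(dones), infos


vec = SequentialVecEnv([lambda: CountdownEnv(1), lambda: CountdownEnv(3)])
first_obs = vec.reset()  # shape (2, 1): one row per environment
obs, rewards, dones, infos = vec.step([0, 0])
```

&lt;p&gt;Pass factories rather than env instances. When building several factories in a loop, bind the loop variable explicitly, as in lambda n=n: CountdownEnv(n). A bare lambda: CountdownEnv(n) would capture only the last n.&lt;/p&gt;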
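&lt;h3 id=&quot;自定义环境&quot;&gt;&lt;a href=&quot;#自定义环境&quot; class=&quot;headerlink&quot; title=&quot;自定义环境&quot;&gt;&lt;/a&gt;Custom environment&lt;/h3&gt;&lt;p&gt;To build your own environment, subclass gym.Env, override its methods, and set a few attributes. The template looks like this:&lt;/p&gt;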
&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;8&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;9&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;10&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;11&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;12&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;13&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;14&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;15&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;16&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;17&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;18&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;19&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;20&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;21&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;22&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;23&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;24&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;25&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;26&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;27&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;28&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;29&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;30&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;31&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;32&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;33&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; gym&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; numpy &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; np&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; gym &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; spaces&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;comment&quot;&gt;# Placeholder sizes (hypothetical values); replace them with your own&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;N_DISCRETE_ACTIONS = &lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;N_CHANNELS, HEIGHT, WIDTH = &lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;64&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;64&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span
class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;title class_&quot;&gt;CustomEnv&lt;/span&gt;(gym.Env):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;string&quot;&gt;&amp;quot;&amp;quot;&amp;quot;Custom Environment that follows gym interface&amp;quot;&amp;quot;&amp;quot;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    metadata = &amp;#123;&lt;span class=&quot;string&quot;&gt;&amp;#x27;render.modes&amp;#x27;&lt;/span&gt;: [&lt;span class=&quot;string&quot;&gt;&amp;#x27;human&amp;#x27;&lt;/span&gt;]&amp;#125;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;__init__&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self, arg1=&lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt;, arg2=&lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt;&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;built_in&quot;&gt;super&lt;/span&gt;(CustomEnv, &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;).__init__()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;comment&quot;&gt;# Define action and observation space&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;comment&quot;&gt;# They must be gym.spaces objects&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;comment&quot;&gt;# Example when using discrete actions:&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;comment&quot;&gt;# Example for using image as input (channel-first; channel-last also works):&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.observation_space = spaces.Box(low=&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;, high=&lt;span class=&quot;number&quot;&gt;255&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                                            shape=(N_CHANNELS, HEIGHT, WIDTH), dtype=np.uint8)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span
class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;step&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self, action&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        observation = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.observation_space.sample()  &lt;span class=&quot;comment&quot;&gt;# placeholder: compute the real next observation&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        reward, done, info = &lt;span class=&quot;number&quot;&gt;0.0&lt;/span&gt;, &lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;, &amp;#123;&amp;#125;  &lt;span class=&quot;comment&quot;&gt;# placeholder: task-specific reward and termination&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; observation, reward, done, info&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;reset&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        observation = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.observation_space.sample()  &lt;span class=&quot;comment&quot;&gt;# placeholder: the real initial observation&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; observation  &lt;span class=&quot;comment&quot;&gt;# reward, done, info can&amp;#x27;t be included&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;render&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self, mode=&lt;span class=&quot;string&quot;&gt;&amp;#x27;human&amp;#x27;&lt;/span&gt;&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        ...&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;close&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;pass&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;

&lt;p&gt;The main methods to implement are:&lt;/p&gt;
&lt;blockquote&gt;
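&lt;p&gt;reset() runs at the start of every episode and returns the initial observation.&lt;/p&gt;
&lt;p&gt;step(action) takes an action, lets the agent execute it in the environment, and returns (new observation, reward, done flag, info).&lt;/p&gt;
&lt;p&gt;Optional: render(mode&amp;#x3D;&amp;#x27;human&amp;#x27;) renders the environment.&lt;/p&gt;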
&lt;/blockquote&gt;
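&lt;p&gt;To make the reset()/step() contract concrete without installing gym, here is a hypothetical toy environment, LineWalkEnv. The agent moves left or right along a line until it reaches a goal or runs out of steps. It is paired with the same random-rollout loop used at the top of this post. The class is duck-typed for brevity; a real environment should subclass gym.Env and declare action_space and observation_space as in the template above. sample_action() stands in for env.action_space.sample().&lt;/p&gt;

```python
import random


class LineWalkEnv:
    """Hypothetical 1-D environment following the classic gym API."""

    def __init__(self, goal=3, max_steps=20, seed=None):
        self.goal = goal
        self.max_steps = max_steps
        self.rng = random.Random(seed)
        self.pos = 0
        self.steps = 0

    def reset(self):
        # Called at the start of every episode; returns only the observation
        self.pos = 0
        self.steps = 0
        return self.pos

    def step(self, action):
        # action 0 = move left, 1 = move right
        self.pos += 1 if action == 1 else -1
        self.steps += 1
        reached = self.pos == self.goal
        done = reached or self.steps >= self.max_steps
        reward = 1.0 if reached else -0.1
        return self.pos, reward, done, {"steps": self.steps}

    def sample_action(self):
        # Stand-in for env.action_space.sample()
        return self.rng.randint(0, 1)

    def render(self, mode="human"):
        print(f"pos={self.pos} goal={self.goal}")


def run_episode(env):
    """Random rollout, mirroring the loop at the top of the post."""
    env.reset()
    done, score = False, 0.0
    while not done:
        _, reward, done, _ = env.step(env.sample_action())
        score += reward
    return score
```

&lt;p&gt;Keeping reset() free of reward and done values matches the template's note that only the observation is returned. Everything else about an episode is reported by step().&lt;/p&gt;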
    
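&lt;p&gt;The robot-environment template in this post leaves state normalization (get_state) and the reward (get_reward) as placeholders. The sketch below shows one way to fill them in for its 39-dimensional state and its three suggested reward terms (survival, heading, speed). The limits (pi rad, 10 rad/s, 2 m/s), the equal weights, and the -1 fall penalty are hypothetical values chosen for illustration. The function names are not part of the template.&lt;/p&gt;

```python
import numpy as np

# Hypothetical limits used for normalization; replace them with your robot's values
MAX_ANGLE = np.pi      # rad
MAX_ANG_VEL = 10.0     # rad/s
MAX_SPEED = 2.0        # m/s


def build_state(joint_angles, joint_vels, direction, speed):
    """Pack 18 angles + 18 angular velocities + 2 direction + 1 speed = 39 values."""
    angles = np.clip(np.asarray(joint_angles, dtype=float) / MAX_ANGLE, -1.0, 1.0)
    vels = np.clip(np.asarray(joint_vels, dtype=float) / MAX_ANG_VEL, -1.0, 1.0)
    heading = np.asarray(direction, dtype=float)
    heading = heading / (np.linalg.norm(heading) + 1e-8)  # unit vector, entries in [-1, 1]
    spd = np.clip([speed / MAX_SPEED], 0.0, 1.0)           # speed mapped into [0, 1]
    state = np.concatenate([angles, vels, heading, spd]).astype(np.float32)
    if state.shape != (39,):
        raise ValueError(f"expected 39 values, got shape {state.shape}")
    return state


def shaped_reward(alive, actual_dir, target_dir, actual_speed, target_speed,
                  w_alive=1.0, w_dir=1.0, w_speed=1.0):
    """Survival bonus + heading agreement - speed error (weights are hypothetical)."""
    if not alive:
        return -1.0  # hypothetical penalty when the robot falls
    a = np.asarray(actual_dir, dtype=float)
    t = np.asarray(target_dir, dtype=float)
    cos_sim = float(np.dot(a, t) / (np.linalg.norm(a) * np.linalg.norm(t) + 1e-8))
    speed_err = abs(actual_speed - target_speed) / MAX_SPEED
    return w_alive + w_dir * cos_sim - w_speed * speed_err
```

&lt;p&gt;Returning a flat float32 vector of shape (39,) keeps the state consistent with the Box observation space declared in __init__.&lt;/p&gt;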
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/862669431.html"/>
    <id>https://blog.aivgg.com/posts/862669431.html</id>
    <published>2026-06-10T16:18:37.527Z</published>
    <updated>2026-06-10T16:24:04.383Z</updated>
    
    <content type="html"><![CDATA[<figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">https://github.com/eleurent/highway-env</span><br><span class="line">https://github.com/MCZhi/Driving-IRL-NGSIM/blob/main/NGSIM_env/envs/ngsim_env.py</span><br></pre></td></tr></table></figure><p>1301,3725,511,1118847351400,20.428,717.996,6451628.438,1872868.223,11,4.5,2,31.39,-10.61,2,NA,NA,NA,NA,NA,NA,1289,1304,51.13,1.63,us-101</p><p>The NGSIM dataset covers four sites: US-101, I-80, Lankershim, and Peachtree. US-101 and I-80 record vehicle trajectories on freeways, while Lankershim and Peachtree record trajectories on urban streets.</p><table><thead><tr><th align="left">Field</th><th>Description</th></tr></thead><tbody><tr><td align="left">Vehicle_Id</td><td>Vehicle identification number, assigned in ascending order of entry time into the section; IDs may be reused.</td></tr><tr><td align="left">Frame_Id</td><td>Frame number of this record (ascending by start time); frame numbers do not repeat for the same Vehicle_Id.</td></tr><tr><td align="left">Total_Frame</td><td>Total number of frames in which this vehicle appears in the dataset.</td></tr><tr><td align="left">Global_Time</td><td>Timestamp (ms).</td></tr><tr><td align="left">Local_X</td><td>Lateral (X) coordinate of the front center of the vehicle, in feet, relative to the left-most edge of the section in the direction of travel.</td></tr><tr><td align="left">Local_Y</td><td>Longitudinal (Y) coordinate of the front center of the vehicle, in feet, relative to the entry edge of the section in the direction of travel.</td></tr><tr><td align="left">Local_X, Local_Y</td><td>Coordinates within the recording area; each recording area has its own coordinate system with a different origin.</td></tr><tr><td align="left">Global_X, Global_Y</td><td>Global coordinates (in feet) with a single common origin; useful for filtering the data.</td></tr><tr><td align="left">v_length</td><td>Vehicle length (feet).</td></tr><tr><td align="left">v_Width</td><td>Vehicle width (feet).</td></tr><tr><td align="left">v_Class</td><td>Vehicle type: 1 = motorcycle, 2 = automobile, 3 = truck.</td></tr><tr><td align="left">v_Vel</td><td><strong>Instantaneous velocity of the vehicle, in feet&#x2F;second.</strong></td></tr><tr><td align="left">v_Acc</td><td><strong>Instantaneous acceleration of the vehicle, in feet&#x2F;second squared.</strong></td></tr><tr><td align="left">Lane_ID</td><td>Current lane of the vehicle. Lane 1 is the left-most lane and lane 5 the right-most lane.</td></tr><tr><td align="left">O_Zone</td><td>Origin zone of the vehicle, i.e., where it enters the tracking system. The study area has 11 origins, numbered 101 to 111. See the data analysis report for details.</td></tr><tr><td 
align="left">D_Zone</td><td>Destination zone of the vehicle, i.e., where it leaves the tracking system. The study area has 10 destinations, numbered 201 to 211. Origin 102 is a one-way exit, so there is no associated destination 202. See the data analysis report for details.</td></tr><tr><td align="left">Int_ID</td><td>Intersection in which the vehicle is traveling. Intersections are numbered 1 to 4, with intersection 1 at the southernmost end and intersection 4 at the northernmost end of the study area. A value of 0 means the vehicle is not in the immediate vicinity of an intersection and is instead identified with a section of Lankershim Boulevard (Section_ID below). See the data analysis report for details.</td></tr><tr><td align="left">Section_ID</td><td>Section in which the vehicle is traveling. Lankershim Blvd is divided into five sections (south of intersection 1; between intersections 1 and 2, 2 and 3, and 3 and 4; north of intersection 4). A value of 0 means the vehicle is not identified with a section of Lankershim Boulevard and is in the immediate vicinity of an intersection (Int_ID above). See the data analysis report for details.</td></tr><tr><td align="left">Direction</td><td><strong>Moving direction of the vehicle</strong>. 1 = east-bound (EB), 2 = north-bound (NB), 3 = west-bound (WB), 4 = south-bound (SB).</td></tr><tr><td align="left">Movement</td><td><strong>Movement of the vehicle</strong>. 1 = through (TH), 2 = left turn (LT), 3 = right turn (RT).</td></tr><tr><td align="left">Preceding</td><td>Vehicle ID of the lead vehicle in the same lane. A value of 0 means there is no preceding vehicle, which happens at the end of the study section and on off-ramps.</td></tr><tr><td align="left">Following</td><td>Vehicle ID of the vehicle following the subject vehicle in the same lane. A value of 0 means there is no following vehicle, which happens at the beginning of the study section and on on-ramps.</td></tr><tr><td align="left">Space_Headway</td><td>Spacing: distance between the front center of the vehicle and the front center of the preceding vehicle (feet).</td></tr><tr><td align="left">Time_Headway</td><td>Time headway (seconds): time to travel from the front center of the vehicle to the front center of the preceding vehicle at the vehicle's current speed.</td></tr><tr><td align="left">Location</td><td>Name of the street or freeway.</td></tr></tbody></table><pre><code>        v_Vel  v_Acc  Lane_ID     x_prime    y_prime  v_Class  Space_Headway
id
10_436  43.82  -1.59        1   93.445584  -1.861718        2           0.00
12_443  35.26   4.49        1   77.581354  -1.745590        2          52.05
13_432  39.48   6.21        2  100.954027  -5.892089        2          98.13
14_515  36.66 -11.20        5  106.701336 -15.726461        3           0.00
18_291  41.14   0.15        5   78.375053 -15.703906        2          92.93
20_414  40.01   0.00        3   61.464749  -8.023250        2         123.13
21_439  43.55   7.14        4   85.675927 -11.882628        2          82.24
22_441  34.92   0.09        2   80.107841  -5.302910        2          68.39
23_438  37.55   0.00        1   61.004501  -1.665427        2          54.39
25_436  44.98   0.00        4   55.262678 -11.570208        2          99.78
26_438  34.99   0.00        2   65.094002  -4.779874        2          49.26
27_432  40.00   0.00        1   34.377782  -1.819046        2          87.36
2_437   44.99   0.00        2  130.865270  -5.408676        2           0.00
31_465  35.02  -0.23        5   60.066326 -15.641117        2          60.07
32_438  34.41  -1.74        2   50.624842  -4.710684        2          47.47
34_451  40.01   0.28        4   37.907366 -11.704015        2          56.94
35_280  31.20  -2.39        5   43.132858 -16.138246        2          55.56
39_450  41.50   0.02        1   21.973032  -0.281026        2          40.70
40_391  43.69 -10.63        4   20.277430 -12.428220        2          57.84
47_428  40.00   0.00        3   14.348765  -8.414309        2         154.58
48_507  35.08  -1.87        5   32.537400 -15.481097        2          34.76
5_452   38.55   0.00        4  129.526284 -12.498629        2           0.00
8_448   39.97   0.09        4  110.742070 -12.588850        2          61.63
9_409   45.88  -6.52        3   98.994163  -8.774887        2           0.00</code></pre><p><img src="/home/tony/.config/Typora/typora-user-images/image-20230314102459825.png" alt="image-20230314102459825"></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span 
class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> torch <span class="keyword">as</span> th</span><br><span class="line"><span class="keyword">from</span> imitation.algorithms <span class="keyword">import</span> base</span><br><span class="line"><span class="keyword">from</span> imitation.data <span class="keyword">import</span> types</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">test_make_data_loader</span>():</span><br><span class="line">    <span class="string">&quot;&quot;&quot;Tests data loader produces same results for same input in different formats.&quot;&quot;&quot;</span></span><br><span class="line">    trajs = [</span><br><span class="line">        types.Trajectory(</span><br><span class="line">            
obs=np.array([<span class="number">0</span>, <span class="number">1</span>]),</span><br><span class="line">            acts=np.array([<span class="number">100</span>]),</span><br><span class="line">            infos=<span class="literal">None</span>,</span><br><span class="line">            terminal=<span class="literal">True</span>,</span><br><span class="line">        ),</span><br><span class="line">        types.Trajectory(</span><br><span class="line">            obs=np.array([<span class="number">4</span>, <span class="number">5</span>, <span class="number">6</span>]),</span><br><span class="line">            acts=np.array([<span class="number">102</span>, <span class="number">103</span>]),</span><br><span class="line">            infos=<span class="literal">None</span>,</span><br><span class="line">            terminal=<span class="literal">True</span>,</span><br><span class="line">        ),</span><br><span class="line">        types.Trajectory(</span><br><span class="line">            obs=np.array([<span class="number">10</span>, <span class="number">11</span>, <span class="number">12</span>, <span class="number">13</span>]),</span><br><span class="line">            acts=np.array([<span class="number">104</span>, <span class="number">105</span>, <span class="number">106</span>]),</span><br><span class="line">            infos=<span class="literal">None</span>,</span><br><span class="line">            terminal=<span class="literal">False</span>,</span><br><span class="line">        ),</span><br><span class="line">    ]</span><br><span class="line">    trans = types.Transitions(</span><br><span class="line">        obs=np.array([<span class="number">0</span>, <span class="number">4</span>, <span class="number">5</span>, <span class="number">10</span>, <span class="number">11</span>, <span class="number">12</span>]),</span><br><span class="line">        acts=np.array([<span class="number">100</span>, <span class="number">102</span>, <span 
class="number">103</span>, <span class="number">104</span>, <span class="number">105</span>, <span class="number">106</span>]),</span><br><span class="line">        next_obs=np.array([<span class="number">1</span>, <span class="number">5</span>, <span class="number">6</span>, <span class="number">11</span>, <span class="number">12</span>, <span class="number">13</span>]),</span><br><span class="line">        dones=np.array([<span class="literal">True</span>, <span class="literal">False</span>, <span class="literal">True</span>, <span class="literal">False</span>, <span class="literal">False</span>, <span class="literal">False</span>]),</span><br><span class="line">        infos=np.array([&#123;&#125;] * <span class="number">6</span>),</span><br><span class="line">    )</span><br><span class="line">    trans_mapping = [</span><br><span class="line">        &#123;</span><br><span class="line">            <span class="string">&quot;obs&quot;</span>: np.array([<span class="number">0</span>, <span class="number">4</span>]),</span><br><span class="line">            <span class="string">&quot;acts&quot;</span>: np.array([<span class="number">100</span>, <span class="number">102</span>]),</span><br><span class="line">            <span class="string">&quot;next_obs&quot;</span>: np.array([<span class="number">1</span>, <span class="number">5</span>]),</span><br><span class="line">            <span class="string">&quot;dones&quot;</span>: np.array([<span class="literal">True</span>, <span class="literal">False</span>]),</span><br><span class="line">            <span class="string">&quot;infos&quot;</span>: np.array([&#123;&#125;, &#123;&#125;]),</span><br><span class="line">        &#125;,</span><br><span class="line">        &#123;</span><br><span class="line">            <span class="string">&quot;obs&quot;</span>: np.array([<span class="number">5</span>, <span class="number">10</span>]),</span><br><span class="line">            <span 
class="string">&quot;acts&quot;</span>: np.array([<span class="number">103</span>, <span class="number">104</span>]),</span><br><span class="line">            <span class="string">&quot;next_obs&quot;</span>: np.array([<span class="number">6</span>, <span class="number">11</span>]),</span><br><span class="line">            <span class="string">&quot;dones&quot;</span>: np.array([<span class="literal">True</span>, <span class="literal">False</span>]),</span><br><span class="line">            <span class="string">&quot;infos&quot;</span>: np.array([&#123;&#125;, &#123;&#125;]),</span><br><span class="line">        &#125;,</span><br><span class="line">        &#123;</span><br><span class="line">            <span class="string">&quot;obs&quot;</span>: np.array([<span class="number">11</span>, <span class="number">12</span>]),</span><br><span class="line">            <span class="string">&quot;acts&quot;</span>: np.array([<span class="number">105</span>, <span class="number">106</span>]),</span><br><span class="line">            <span class="string">&quot;next_obs&quot;</span>: np.array([<span class="number">12</span>, <span class="number">13</span>]),</span><br><span class="line">            <span class="string">&quot;dones&quot;</span>: np.array([<span class="literal">False</span>, <span class="literal">False</span>]),</span><br><span class="line">            <span class="string">&quot;infos&quot;</span>: np.array([&#123;&#125;, &#123;&#125;]),</span><br><span class="line">        &#125;,</span><br><span class="line">    ]</span><br><span class="line"></span><br><span class="line">    <span class="keyword">for</span> data <span class="keyword">in</span> [trajs, trans, trans_mapping]:</span><br><span class="line">        data_loader = base.make_data_loader(</span><br><span class="line">            data,</span><br><span class="line">            batch_size=<span class="number">2</span>,</span><br><span class="line">            data_loader_kwargs=<span class="built_in">dict</span>(shuffle=<span 
class="literal">False</span>, drop_last=<span class="literal">False</span>),</span><br><span class="line">        )</span><br><span class="line">        <span class="keyword">for</span> batch, expected_batch <span class="keyword">in</span> <span class="built_in">zip</span>(data_loader, trans_mapping):</span><br><span class="line">            <span class="keyword">assert</span> batch.keys() == expected_batch.keys()</span><br><span class="line">            <span class="keyword">for</span> k <span class="keyword">in</span> batch.keys():</span><br><span class="line">                v = batch[k]</span><br><span class="line">                <span class="keyword">if</span> <span class="built_in">isinstance</span>(v, th.Tensor):</span><br><span class="line">                    v = v.numpy()</span><br><span class="line">                <span class="keyword">assert</span> np.<span class="built_in">all</span>(v == expected_batch[k])</span><br><span class="line"></span><br></pre></td></tr></table></figure><ol><li>Of course you don’t have to generate the demonstrations. This is just done in the examples to make them more self-contained. You can pass your trajectories as a sequence of <code>imitation.data.types.Trajectory</code> to GAIL&#x2F;AIRL.</li></ol>]]></content>
    
    <summary type="html">
    
      &lt;figure class=&quot;highlight plaintext&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;https://github.com/eleurent/highway-env&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;https://github.com/MCZhi/Driving-IRL-NGSIM/blob/main/NGSIM_env/envs/ngsim_env.py&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;

&lt;pre&gt;&lt;code&gt;1301,3725,511,1118847351400,20.428,717.996,6451628.438,1872868.223,11,4.5,2,31.39,-10.61,2,NA,NA,NA,NA,NA,NA,1289,1304,51.13,1.63,us-101&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;The NGSIM dataset covers four sites: US-101, I-80, Lankershim and Peachtree. US-101 and I-80 record vehicle trajectories on freeways, while Lankershim and Peachtree record vehicle trajectories on urban streets.&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th align=&quot;left&quot;&gt;Field&lt;/th&gt;
&lt;th&gt;Description&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;&lt;tr&gt;
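&lt;td align=&quot;left&quot;&gt;Vehicle_ID&lt;/td&gt;
&lt;td&gt;Vehicle identification number, assigned in ascending order of the time the vehicle entered the section. IDs can be reused, so a repeated ID does not necessarily refer to the same vehicle.&lt;/td&gt;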
&lt;/tr&gt;
&lt;tr&gt;
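&lt;td align=&quot;left&quot;&gt;Frame_ID&lt;/td&gt;
&lt;td&gt;Frame number of this record, in ascending order of start time (one frame every 0.1 s). A given Vehicle_ID never has the same frame number twice.&lt;/td&gt;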
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Total_Frames&lt;/td&gt;
&lt;td&gt;Total number of frames in which this vehicle appears in the dataset.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Global_Time&lt;/td&gt;
&lt;td&gt;Timestamp in milliseconds since 1970-01-01 (Unix epoch time).&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Local_X&lt;/td&gt;
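&lt;td&gt;Lateral (X) coordinate of the front center of the vehicle, in feet, measured from the left-most edge of the section in the direction of travel.&lt;/td&gt;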
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Local_Y&lt;/td&gt;
&lt;td&gt;Longitudinal (Y) coordinate of the front center of the vehicle, in feet, measured from the entry edge of the section in the direction of travel.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
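&lt;td align=&quot;left&quot;&gt;Local_X, Local_Y&lt;/td&gt;
&lt;td&gt;Coordinates local to the recording area. Each area has its own coordinate system with its own origin, so local coordinates from different areas are not directly comparable.&lt;/td&gt;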
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Global_X, Global_Y&lt;/td&gt;
&lt;td&gt;Global coordinates in feet. They share a single origin, which makes them useful for filtering data.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;v_length&lt;/td&gt;
&lt;td&gt;Vehicle length in feet.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;v_Width&lt;/td&gt;
&lt;td&gt;Vehicle width in feet.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;v_Class&lt;/td&gt;
&lt;td&gt;Vehicle type: 1 = motorcycle, 2 = car, 3 = truck.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;v_Vel&lt;/td&gt;
&lt;td&gt;&lt;strong&gt;Instantaneous speed of the vehicle, in feet per second (ft&amp;#x2F;s).&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;v_Acc&lt;/td&gt;
&lt;td&gt;&lt;strong&gt;Instantaneous acceleration of the vehicle, in feet per second squared (ft&amp;#x2F;s²).&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Lane_ID&lt;/td&gt;
&lt;td&gt;Current lane of the vehicle. Lane 1 is the leftmost lane and lane 5 is the rightmost lane.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;O_Zone&lt;/td&gt;
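&lt;td&gt;Origin zone of the vehicle, i.e. where it entered the tracking system. The study area has 11 origins, numbered 101 to 111. See the data analysis report for details.&lt;/td&gt;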
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;D_Zone&lt;/td&gt;
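&lt;td&gt;Destination zone of the vehicle, i.e. where it left the tracking system. The study area has 10 destinations, numbered 201 to 211. Origin 102 is a one-way exit, so there is no matching destination 202. See the data analysis report for details.&lt;/td&gt;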
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Int_ID&lt;/td&gt;
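&lt;td&gt;Intersection in which the vehicle is traveling. Intersections are numbered 1 to 4, with intersection 1 at the southern end and intersection 4 at the northern end of the study area. A value of 0 means the vehicle was not near an intersection and was instead assigned to a section of Lankershim Boulevard (see Section_ID below). See the data analysis report for details.&lt;/td&gt;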
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Section_ID&lt;/td&gt;
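&lt;td&gt;Road section in which the vehicle is traveling. Lankershim Blvd is divided into five sections: south of intersection 1; between intersections 1 and 2, 2 and 3, and 3 and 4; and north of intersection 4. A value of 0 means the vehicle was not assigned to a section of Lankershim Boulevard and was instead near an intersection (see Int_ID above). See the data analysis report for details.&lt;/td&gt;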
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Direction&lt;/td&gt;
&lt;td&gt;&lt;strong&gt;Direction of travel&lt;/strong&gt;: 1 = eastbound (EB), 2 = northbound (NB), 3 = westbound (WB), 4 = southbound (SB).&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Movement&lt;/td&gt;
&lt;td&gt;&lt;strong&gt;Vehicle movement&lt;/strong&gt;: 1 = through (TH), 2 = left turn (LT), 3 = right turn (RT).&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Preceding&lt;/td&gt;
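&lt;td&gt;Vehicle ID of the preceding vehicle in the same lane. A value of 0 means there is no preceding vehicle, which happens at the end of the study section and on off-ramps.&lt;/td&gt;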
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Following&lt;/td&gt;
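&lt;td&gt;Vehicle ID of the vehicle following this one in the same lane. A value of 0 means there is no following vehicle, which happens at the start of the study section and on on-ramps.&lt;/td&gt;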
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Space_Headway&lt;/td&gt;
&lt;td&gt;Space headway in feet: the distance from the front center of the vehicle to the front center of the preceding vehicle.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Time_Headway&lt;/td&gt;
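&lt;td&gt;Time headway in seconds: the time needed to travel from the front center of the vehicle to the front center of the preceding vehicle at the vehicle's current speed.&lt;/td&gt;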
&lt;/tr&gt;
&lt;tr&gt;
&lt;td align=&quot;left&quot;&gt;Location&lt;/td&gt;
&lt;td&gt;Street or freeway name.&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;&lt;/table&gt;
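&lt;p&gt;Here is a minimal sketch of decoding one record, using the US-101 sample row shown above. The column order follows the field table. The helper names &lt;code&gt;parse_row&lt;/code&gt; and &lt;code&gt;to_si&lt;/code&gt; and the conversion to SI units (1 ft = 0.3048 m, &lt;code&gt;Global_Time&lt;/code&gt; read as Unix milliseconds) are illustrative choices, not part of any official NGSIM tooling. The sketch also checks that &lt;code&gt;Time_Headway&lt;/code&gt; is consistent with &lt;code&gt;Space_Headway / v_Vel&lt;/code&gt;: 51.13 / 31.39 ≈ 1.63 s, which matches the recorded value.&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;
```python
import csv
import io
from datetime import datetime, timezone

# Column order of the NGSIM trajectory CSV, following the field table above.
COLUMNS = [
    "Vehicle_ID", "Frame_ID", "Total_Frames", "Global_Time",
    "Local_X", "Local_Y", "Global_X", "Global_Y",
    "v_length", "v_Width", "v_Class", "v_Vel", "v_Acc", "Lane_ID",
    "O_Zone", "D_Zone", "Int_ID", "Section_ID", "Direction", "Movement",
    "Preceding", "Following", "Space_Headway", "Time_Headway", "Location",
]
# Columns holding IDs, counts or category codes rather than measurements.
INT_COLUMNS = {
    "Vehicle_ID", "Frame_ID", "Total_Frames", "Global_Time", "v_Class",
    "Lane_ID", "O_Zone", "D_Zone", "Int_ID", "Section_ID", "Direction",
    "Movement", "Preceding", "Following",
}
FT_TO_M = 0.3048  # 1 international foot = 0.3048 m exactly

# The US-101 sample row quoted above.
SAMPLE = ("1301,3725,511,1118847351400,20.428,717.996,6451628.438,"
          "1872868.223,11,4.5,2,31.39,-10.61,2,NA,NA,NA,NA,NA,NA,"
          "1289,1304,51.13,1.63,us-101")


def parse_row(line):
    """Map one CSV line to {column: value}; 'NA' becomes None."""
    values = next(csv.reader(io.StringIO(line)))
    if len(values) != len(COLUMNS):
        raise ValueError(f"expected {len(COLUMNS)} fields, got {len(values)}")
    row = {}
    for name, raw in zip(COLUMNS, values):
        if raw == "NA":
            row[name] = None
        elif name == "Location":
            row[name] = raw
        elif name in INT_COLUMNS:
            row[name] = int(float(raw))
        else:
            row[name] = float(raw)
    return row


def to_si(row):
    """Convert the timestamp and the feet-based kinematic fields to SI units."""
    return {
        "time": datetime.fromtimestamp(row["Global_Time"] / 1000, tz=timezone.utc),
        "local_x_m": row["Local_X"] * FT_TO_M,
        "local_y_m": row["Local_Y"] * FT_TO_M,
        "speed_mps": row["v_Vel"] * FT_TO_M,
        "accel_mps2": row["v_Acc"] * FT_TO_M,
        "space_headway_m": row["Space_Headway"] * FT_TO_M,
    }


row = parse_row(SAMPLE)
si = to_si(row)
# Time_Headway should roughly equal Space_Headway / v_Vel (feet / feet-per-second).
derived_time_headway = row["Space_Headway"] / row["v_Vel"]
print(si["time"].date(), round(si["speed_mps"], 2), round(derived_time_headway, 2))
```
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;Run as-is, it prints &lt;code&gt;2005-06-15 9.57 1.63&lt;/code&gt;. The record dates from 15 June 2005 (UTC), the vehicle is moving at about 9.57 m/s, and the derived time headway agrees with the recorded one.&lt;/p&gt;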
&lt;pre&gt;&lt;code&gt;        v_Vel  v_Acc  Lane_ID     x_prime    y_prime  v_Class  Space_Headway
id
10_436  43.82  -1.59        1   93.445584  -1.861718        2           0.00
12_443  35.26   4.49        1   77.581354  -1.745590        2          52.05
13_432  39.48   6.21        2  100.954027  -5.892089        2          98.13
14_515  36.66 -11.20        5  106.701336 -15.726461        3           0.00
18_291  41.14   0.15        5   78.375053 -15.703906        2          92.93
20_414  40.01   0.00        3   61.464749  -8.023250        2         123.13
21_439  43.55   7.14        4   85.675927 -11.882628        2          82.24
22_441  34.92   0.09        2   80.107841  -5.302910        2          68.39
23_438  37.55   0.00        1   61.004501  -1.665427        2          54.39
25_436  44.98   0.00        4   55.262678 -11.570208        2          99.78
26_438  34.99   0.00        2   65.094002  -4.779874        2          49.26
27_432  40.00   0.00        1   34.377782  -1.819046        2          87.36
2_437   44.99   0.00        2  130.865270  -5.408676        2           0.00
31_465  35.02  -0.23        5   60.066326 -15.641117        2          60.07
32_438  34.41  -1.74        2   50.624842  -4.710684        2          47.47
34_451  40.01   0.28        4   37.907366 -11.704015        2          56.94
35_280  31.20  -2.39        5   43.132858 -16.138246        2          55.56
39_450  41.50   0.02        1   21.973032  -0.281026        2          40.70
40_391  43.69 -10.63        4   20.277430 -12.428220        2          57.84
47_428  40.00   0.00        3   14.348765  -8.414309        2         154.58
48_507  35.08  -1.87        5   32.537400 -15.481097        2          34.76
5_452   38.55   0.00        4  129.526284 -12.498629        2           0.00
8_448   39.97   0.09        4  110.742070 -12.588850        2          61.63
9_409   45.88  -6.52        3   98.994163  -8.774887        2           0.00
&lt;/code&gt;&lt;/pre&gt;
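&lt;p&gt;The post does not define &lt;code&gt;x_prime&lt;/code&gt; and &lt;code&gt;y_prime&lt;/code&gt;. In the lane-1 rows of the snapshot above, the &lt;code&gt;x_prime&lt;/code&gt; gap to the vehicle ahead equals &lt;code&gt;Space_Headway&lt;/code&gt; × 0.3048 to within about 0.001. This suggests that &lt;code&gt;x_prime&lt;/code&gt; is a longitudinal position in meters, though that is an inference from these numbers, not a documented definition. The sketch below runs that check using a hypothetical helper, &lt;code&gt;gaps_vs_headway&lt;/code&gt;.&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;
```python
import math

FT_TO_M = 0.3048

# (id, x_prime, Space_Headway in feet) for the Lane_ID == 1 rows of the snapshot above.
LANE_1 = [
    ("10_436", 93.445584, 0.00),
    ("12_443", 77.581354, 52.05),
    ("23_438", 61.004501, 54.39),
    ("27_432", 34.377782, 87.36),
    ("39_450", 21.973032, 40.70),
]


def gaps_vs_headway(rows):
    """Pair each vehicle's x_prime gap to the vehicle ahead with its
    Space_Headway converted from feet to meters (assumes x_prime is in meters)."""
    rows = sorted(rows, key=lambda r: r[1], reverse=True)  # front-most vehicle first
    result = []
    for (_, lead_x, _), (veh_id, x, headway_ft) in zip(rows, rows[1:]):
        result.append((veh_id, lead_x - x, headway_ft * FT_TO_M))
    return result


for veh_id, gap, headway_m in gaps_vs_headway(LANE_1):
    print(veh_id, round(gap, 2), round(headway_m, 2),
          math.isclose(gap, headway_m, abs_tol=0.01))
```
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;Every printed comparison ends in &lt;code&gt;True&lt;/code&gt;.&lt;/p&gt;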
&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;8&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;9&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;10&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;11&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;12&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;13&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;14&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;15&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;16&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;17&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;18&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;19&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;20&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;21&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;22&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;23&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;24&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;25&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;26&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;27&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;28&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;29&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;30&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;31&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;32&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;33&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;34&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;35&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;36&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;37&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;38&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;39&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;40&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;41&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;42&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;43&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;44&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;45&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;46&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;47&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;48&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;49&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;50&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;51&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;52&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;53&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;54&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;55&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;56&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;57&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;58&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;59&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;60&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;61&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;62&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;63&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;64&lt;/span&gt;&lt;br&gt;&lt;span 
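class=&quot;line&quot;&gt;65&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;66&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;67&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;68&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;69&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;70&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;71&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;72&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;73&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; numpy &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; np&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; torch &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; th&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; imitation.algorithms &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; base&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; imitation.data &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; types&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;test_make_data_loader&lt;/span&gt;():&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;string&quot;&gt;&amp;quot;&amp;quot;&amp;quot;Tests data loader produces same results for same input in different formats.&amp;quot;&amp;quot;&amp;quot;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    trajs = [&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        types.Trajectory(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            obs=np.array([&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            acts=np.array([&lt;span class=&quot;number&quot;&gt;100&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            infos=&lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            terminal=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        ),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        types.Trajectory(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            obs=np.array([&lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;6&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            acts=np.array([&lt;span class=&quot;number&quot;&gt;102&lt;/span&gt;, &lt;span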
class=&quot;number&quot;&gt;103&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            infos=&lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            terminal=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        ),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        types.Trajectory(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            obs=np.array([&lt;span class=&quot;number&quot;&gt;10&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;11&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;12&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;13&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            acts=np.array([&lt;span class=&quot;number&quot;&gt;104&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;105&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;106&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            infos=&lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            terminal=&lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        ),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    ]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    trans = types.Transitions(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        obs=np.array([&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;10&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;11&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;12&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        acts=np.array([&lt;span 
class=&quot;number&quot;&gt;100&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;102&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;103&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;104&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;105&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;106&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        next_obs=np.array([&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;6&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;11&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;12&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;13&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        dones=np.array([&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;, &lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;, &lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;, &lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;, &lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;, &lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        infos=np.array([&amp;#123;&amp;#125;] * &lt;span class=&quot;number&quot;&gt;6&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    )&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    trans_mapping = [&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &amp;#123;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;obs&amp;quot;&lt;/span&gt;: np.array([&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;acts&amp;quot;&lt;/span&gt;: np.array([&lt;span 
class=&quot;number&quot;&gt;100&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;102&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;next_obs&amp;quot;&lt;/span&gt;: np.array([&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;dones&amp;quot;&lt;/span&gt;: np.array([&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;, &lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;infos&amp;quot;&lt;/span&gt;: np.array([&amp;#123;&amp;#125;, &amp;#123;&amp;#125;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &amp;#125;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &amp;#123;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;obs&amp;quot;&lt;/span&gt;: np.array([&lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;10&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;acts&amp;quot;&lt;/span&gt;: np.array([&lt;span class=&quot;number&quot;&gt;103&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;104&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;next_obs&amp;quot;&lt;/span&gt;: np.array([&lt;span class=&quot;number&quot;&gt;6&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;11&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;dones&amp;quot;&lt;/span&gt;: np.array([&lt;span 
class=&quot;literal&quot;&gt;True&lt;/span&gt;, &lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;infos&amp;quot;&lt;/span&gt;: np.array([&amp;#123;&amp;#125;, &amp;#123;&amp;#125;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &amp;#125;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &amp;#123;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;obs&amp;quot;&lt;/span&gt;: np.array([&lt;span class=&quot;number&quot;&gt;11&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;12&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;acts&amp;quot;&lt;/span&gt;: np.array([&lt;span class=&quot;number&quot;&gt;105&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;106&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;next_obs&amp;quot;&lt;/span&gt;: np.array([&lt;span class=&quot;number&quot;&gt;12&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;13&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;dones&amp;quot;&lt;/span&gt;: np.array([&lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;, &lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;quot;infos&amp;quot;&lt;/span&gt;: np.array([&amp;#123;&amp;#125;, &amp;#123;&amp;#125;]),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &amp;#125;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    ]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; data 
&lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; [trajs, trans, trans_mapping]:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        data_loader = base.make_data_loader(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            data,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            batch_size=&lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            data_loader_kwargs=&lt;span class=&quot;built_in&quot;&gt;dict&lt;/span&gt;(shuffle=&lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;, drop_last=&lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        )&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; batch, expected_batch &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;zip&lt;/span&gt;(data_loader, trans_mapping):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;assert&lt;/span&gt; batch.keys() == expected_batch.keys()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; k &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; batch.keys():&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                v = batch[k]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;isinstance&lt;/span&gt;(v, th.Tensor):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    v = v.numpy()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;keyword&quot;&gt;assert&lt;/span&gt; np.&lt;span class=&quot;built_in&quot;&gt;all&lt;/span&gt;(v == 
expected_batch[k])&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;
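&lt;p&gt;The expected values in this test come from two steps. First, each trajectory is flattened into transitions: observation i is paired with observation i+1, and &lt;code&gt;done&lt;/code&gt; is set only on the last step of a terminal trajectory. Second, the flat arrays are cut into unshuffled batches of two. The plain-NumPy sketch below reproduces both steps by hand. &lt;code&gt;Traj&lt;/code&gt;, &lt;code&gt;flatten&lt;/code&gt; and &lt;code&gt;batches&lt;/code&gt; are hypothetical helpers for illustration, not imitation's implementation. Judging by the &lt;code&gt;data_loader_kwargs&lt;/code&gt; argument, imitation hands the batching to a PyTorch &lt;code&gt;DataLoader&lt;/code&gt;.&lt;/p&gt;
&lt;pre&gt;&lt;code&gt;
```python
from typing import NamedTuple

import numpy as np


class Traj(NamedTuple):
    """Hypothetical stand-in for imitation.data.types.Trajectory.

    obs holds one more entry than acts: the observation after the last action.
    """
    obs: np.ndarray
    acts: np.ndarray
    terminal: bool


def flatten(trajs):
    """Turn a list of trajectories into flat transition arrays."""
    obs, acts, next_obs, dones = [], [], [], []
    for t in trajs:
        obs.append(t.obs[:-1])      # state in which each action was taken
        next_obs.append(t.obs[1:])  # state reached after each action
        acts.append(t.acts)
        done = np.zeros(len(t.acts), dtype=bool)
        if len(done):
            done[-1] = t.terminal   # only the last step of a terminal episode is done
        dones.append(done)
    return {
        "obs": np.concatenate(obs),
        "acts": np.concatenate(acts),
        "next_obs": np.concatenate(next_obs),
        "dones": np.concatenate(dones),
    }


def batches(flat, batch_size):
    """Yield consecutive batches in order (no shuffling); the last may be short."""
    n = len(flat["obs"])
    for start in range(0, n, batch_size):
        yield {k: v[start:start + batch_size] for k, v in flat.items()}


trajs = [
    Traj(np.array([0, 1]), np.array([100]), True),
    Traj(np.array([4, 5, 6]), np.array([102, 103]), True),
    Traj(np.array([10, 11, 12, 13]), np.array([104, 105, 106]), False),
]
flat = flatten(trajs)
for batch in batches(flat, batch_size=2):
    print(batch["obs"], batch["acts"], batch["next_obs"], batch["dones"])
```
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;The three printed batches split &lt;code&gt;obs&lt;/code&gt;, &lt;code&gt;acts&lt;/code&gt; and &lt;code&gt;next_obs&lt;/code&gt; the same way as &lt;code&gt;trans_mapping&lt;/code&gt; in the test.&lt;/p&gt;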





&lt;ol&gt;
&lt;li&gt;Of course you don’t have to generate the demonstrations. This is just done in the examples to make them more self-contained. You can pass your trajectories as a sequence of &lt;code&gt;imitation.data.types.Trajectory&lt;/code&gt; to GAIL&amp;#x2F;AIRL.&lt;/li&gt;
&lt;/ol&gt;

    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/2540807166.html"/>
    <id>https://blog.aivgg.com/posts/2540807166.html</id>
    <published>2026-06-10T16:18:37.526Z</published>
    <updated>2026-06-10T16:24:04.382Z</updated>
    
    <content type="html"><![CDATA[<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> a2c_ppo_acktr.algo <span class="keyword">import</span> gail</span><br></pre></td></tr></table></figure><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span 
class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span 
class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br><span class="line">126</span><br><span class="line">127</span><br><span class="line">128</span><br><span class="line">129</span><br><span class="line">130</span><br><span class="line">131</span><br><span class="line">132</span><br><span class="line">133</span><br><span class="line">134</span><br><span class="line">135</span><br><span class="line">136</span><br><span class="line">137</span><br><span class="line">138</span><br><span class="line">139</span><br><span class="line">140</span><br><span class="line">141</span><br><span class="line">142</span><br><span class="line">143</span><br><span class="line">144</span><br><span class="line">145</span><br><span class="line">146</span><br><span class="line">147</span><br><span class="line">148</span><br><span class="line">149</span><br><span class="line">150</span><br><span class="line">151</span><br><span class="line">152</span><br><span class="line">153</span><br><span class="line">154</span><br><span class="line">155</span><br><span class="line">156</span><br><span class="line">157</span><br><span class="line">158</span><br><span class="line">159</span><br><span class="line">160</span><br><span class="line">161</span><br><span class="line">162</span><br><span class="line">163</span><br><span class="line">164</span><br><span class="line">165</span><br><span class="line">166</span><br><span class="line">167</span><br><span class="line">168</span><br><span class="line">169</span><br><span class="line">170</span><br><span class="line">171</span><br><span class="line">172</span><br><span 
class="line">173</span><br><span class="line">174</span><br><span class="line">175</span><br><span class="line">176</span><br><span class="line">177</span><br><span class="line">178</span><br><span class="line">179</span><br><span class="line">180</span><br><span class="line">181</span><br><span class="line">182</span><br><span class="line">183</span><br><span class="line">184</span><br><span class="line">185</span><br><span class="line">186</span><br><span class="line">187</span><br><span class="line">188</span><br><span class="line">189</span><br><span class="line">190</span><br><span class="line">191</span><br><span class="line">192</span><br><span class="line">193</span><br><span class="line">194</span><br><span class="line">195</span><br><span class="line">196</span><br><span class="line">197</span><br><span class="line">198</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> copy</span><br><span class="line"><span class="keyword">import</span> glob</span><br><span class="line"><span class="keyword">import</span> os</span><br><span class="line"><span class="keyword">import</span> time</span><br><span class="line"><span class="keyword">from</span> collections <span class="keyword">import</span> deque</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> gym</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torch.nn <span class="keyword">as</span> nn</span><br><span class="line"><span class="keyword">import</span> torch.nn.functional <span class="keyword">as</span> F</span><br><span class="line"><span class="keyword">import</span> torch.optim <span class="keyword">as</span> optim</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> a2c_ppo_acktr 
<span class="keyword">import</span> algo, utils</span><br><span class="line"><span class="keyword">from</span> a2c_ppo_acktr.algo <span class="keyword">import</span> gail</span><br><span class="line"><span class="keyword">from</span> a2c_ppo_acktr.arguments <span class="keyword">import</span> get_args</span><br><span class="line"><span class="keyword">from</span> a2c_ppo_acktr.envs <span class="keyword">import</span> make_vec_envs</span><br><span class="line"><span class="keyword">from</span> a2c_ppo_acktr.model <span class="keyword">import</span> Policy</span><br><span class="line"><span class="keyword">from</span> a2c_ppo_acktr.storage <span class="keyword">import</span> RolloutStorage</span><br><span class="line"><span class="keyword">from</span> evaluation <span class="keyword">import</span> evaluate</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">main</span>():</span><br><span class="line">    args = get_args()</span><br><span class="line"></span><br><span class="line">    torch.manual_seed(args.seed)</span><br><span class="line">    torch.cuda.manual_seed_all(args.seed)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">if</span> args.cuda <span class="keyword">and</span> torch.cuda.is_available() <span class="keyword">and</span> args.cuda_deterministic:</span><br><span class="line">        torch.backends.cudnn.benchmark = <span class="literal">False</span></span><br><span class="line">        torch.backends.cudnn.deterministic = <span class="literal">True</span></span><br><span class="line"></span><br><span class="line">    log_dir = os.path.expanduser(args.log_dir)</span><br><span class="line">    eval_log_dir = log_dir + <span class="string">&quot;_eval&quot;</span></span><br><span class="line">    utils.cleanup_log_dir(log_dir)</span><br><span class="line">    utils.cleanup_log_dir(eval_log_dir)</span><br><span 
class="line"></span><br><span class="line">    torch.set_num_threads(<span class="number">1</span>)</span><br><span class="line">    device = torch.device(<span class="string">&quot;cuda:0&quot;</span> <span class="keyword">if</span> args.cuda <span class="keyword">else</span> <span class="string">&quot;cpu&quot;</span>)</span><br><span class="line"></span><br><span class="line">    envs = make_vec_envs(args.env_name, args.seed, args.num_processes,</span><br><span class="line">                         args.gamma, args.log_dir, device, <span class="literal">False</span>)</span><br><span class="line"></span><br><span class="line">    actor_critic = Policy(</span><br><span class="line">        envs.observation_space.shape,</span><br><span class="line">        envs.action_space,</span><br><span class="line">        base_kwargs=&#123;<span class="string">&#x27;recurrent&#x27;</span>: args.recurrent_policy&#125;)</span><br><span class="line">    actor_critic.to(device)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">if</span> args.algo == <span class="string">&#x27;a2c&#x27;</span>:</span><br><span class="line">        agent = algo.A2C_ACKTR(</span><br><span class="line">            actor_critic,</span><br><span class="line">            args.value_loss_coef,</span><br><span class="line">            args.entropy_coef,</span><br><span class="line">            lr=args.lr,</span><br><span class="line">            eps=args.eps,</span><br><span class="line">            alpha=args.alpha,</span><br><span class="line">            max_grad_norm=args.max_grad_norm)</span><br><span class="line">    <span class="keyword">elif</span> args.algo == <span class="string">&#x27;ppo&#x27;</span>:</span><br><span class="line">        agent = algo.PPO(</span><br><span class="line">            actor_critic,</span><br><span class="line">            args.clip_param,</span><br><span class="line">            args.ppo_epoch,</span><br><span class="line">         
   args.num_mini_batch,</span><br><span class="line">            args.value_loss_coef,</span><br><span class="line">            args.entropy_coef,</span><br><span class="line">            lr=args.lr,</span><br><span class="line">            eps=args.eps,</span><br><span class="line">            max_grad_norm=args.max_grad_norm)</span><br><span class="line">    <span class="keyword">elif</span> args.algo == <span class="string">&#x27;acktr&#x27;</span>:</span><br><span class="line">        agent = algo.A2C_ACKTR(</span><br><span class="line">            actor_critic, args.value_loss_coef, args.entropy_coef, acktr=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">if</span> args.gail:</span><br><span class="line">        <span class="keyword">assert</span> <span class="built_in">len</span>(envs.observation_space.shape) == <span class="number">1</span></span><br><span class="line">        discr = gail.Discriminator(</span><br><span class="line">            envs.observation_space.shape[<span class="number">0</span>] + envs.action_space.shape[<span class="number">0</span>], <span class="number">100</span>,</span><br><span class="line">            device)</span><br><span class="line">        file_name = os.path.join(</span><br><span class="line">            args.gail_experts_dir, <span class="string">&quot;trajs_&#123;&#125;.pt&quot;</span>.<span class="built_in">format</span>(</span><br><span class="line">                args.env_name.split(<span class="string">&#x27;-&#x27;</span>)[<span class="number">0</span>].lower()))</span><br><span class="line">        </span><br><span class="line">        expert_dataset = gail.ExpertDataset(</span><br><span class="line">            file_name, num_trajectories=<span class="number">4</span>, subsample_frequency=<span class="number">20</span>)</span><br><span class="line">        drop_last = <span class="built_in">len</span>(expert_dataset) &gt; 
args.gail_batch_size</span><br><span class="line">        gail_train_loader = torch.utils.data.DataLoader(</span><br><span class="line">            dataset=expert_dataset,</span><br><span class="line">            batch_size=args.gail_batch_size,</span><br><span class="line">            shuffle=<span class="literal">True</span>,</span><br><span class="line">            drop_last=drop_last)</span><br><span class="line"></span><br><span class="line">    rollouts = RolloutStorage(args.num_steps, args.num_processes,</span><br><span class="line">                              envs.observation_space.shape, envs.action_space,</span><br><span class="line">                              actor_critic.recurrent_hidden_state_size)</span><br><span class="line"></span><br><span class="line">    obs = envs.reset()</span><br><span class="line">    rollouts.obs[<span class="number">0</span>].copy_(obs)</span><br><span class="line">    rollouts.to(device)</span><br><span class="line"></span><br><span class="line">    episode_rewards = deque(maxlen=<span class="number">10</span>)</span><br><span class="line"></span><br><span class="line">    start = time.time()</span><br><span class="line">    num_updates = <span class="built_in">int</span>(</span><br><span class="line">        args.num_env_steps) // args.num_steps // args.num_processes</span><br><span class="line">    <span class="keyword">for</span> j <span class="keyword">in</span> <span class="built_in">range</span>(num_updates):</span><br><span class="line"></span><br><span class="line">        <span class="keyword">if</span> args.use_linear_lr_decay:</span><br><span class="line">            <span class="comment"># decrease learning rate linearly</span></span><br><span class="line">            utils.update_linear_schedule(</span><br><span class="line">                agent.optimizer, j, num_updates,</span><br><span class="line">                agent.optimizer.lr <span class="keyword">if</span> args.algo == <span 
class="string">&quot;acktr&quot;</span> <span class="keyword">else</span> args.lr)</span><br><span class="line"></span><br><span class="line">        <span class="keyword">for</span> step <span class="keyword">in</span> <span class="built_in">range</span>(args.num_steps):</span><br><span class="line">            <span class="comment"># Sample actions</span></span><br><span class="line">            <span class="keyword">with</span> torch.no_grad():</span><br><span class="line">                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(</span><br><span class="line">                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],</span><br><span class="line">                    rollouts.masks[step])</span><br><span class="line"></span><br><span class="line">            <span class="comment"># Observe reward and next obs</span></span><br><span class="line">            obs, reward, done, infos = envs.step(action)</span><br><span class="line"></span><br><span class="line">            <span class="keyword">for</span> info <span class="keyword">in</span> infos:</span><br><span class="line">                <span class="keyword">if</span> <span class="string">&#x27;episode&#x27;</span> <span class="keyword">in</span> info.keys():</span><br><span class="line">                    episode_rewards.append(info[<span class="string">&#x27;episode&#x27;</span>][<span class="string">&#x27;r&#x27;</span>])</span><br><span class="line"></span><br><span class="line">            <span class="comment"># If done then clean the history of observations.</span></span><br><span class="line">            masks = torch.FloatTensor(</span><br><span class="line">                [[<span class="number">0.0</span>] <span class="keyword">if</span> done_ <span class="keyword">else</span> [<span class="number">1.0</span>] <span class="keyword">for</span> done_ <span class="keyword">in</span> done])</span><br><span class="line">            bad_masks = 
torch.FloatTensor(</span><br><span class="line">                [[<span class="number">0.0</span>] <span class="keyword">if</span> <span class="string">&#x27;bad_transition&#x27;</span> <span class="keyword">in</span> info.keys() <span class="keyword">else</span> [<span class="number">1.0</span>]</span><br><span class="line">                 <span class="keyword">for</span> info <span class="keyword">in</span> infos])</span><br><span class="line">            rollouts.insert(obs, recurrent_hidden_states, action,</span><br><span class="line">                            action_log_prob, value, reward, masks, bad_masks)</span><br><span class="line"></span><br><span class="line">        <span class="keyword">with</span> torch.no_grad():</span><br><span class="line">            next_value = actor_critic.get_value(</span><br><span class="line">                rollouts.obs[-<span class="number">1</span>], rollouts.recurrent_hidden_states[-<span class="number">1</span>],</span><br><span class="line">                rollouts.masks[-<span class="number">1</span>]).detach()</span><br><span class="line"></span><br><span class="line">        <span class="keyword">if</span> args.gail:</span><br><span class="line">            <span class="keyword">if</span> j &gt;= <span class="number">10</span>:</span><br><span class="line">                envs.venv.<span class="built_in">eval</span>()</span><br><span class="line"></span><br><span class="line">            gail_epoch = args.gail_epoch</span><br><span class="line">            <span class="keyword">if</span> j &lt; <span class="number">10</span>:</span><br><span class="line">                gail_epoch = <span class="number">100</span>  <span class="comment"># Warm up</span></span><br><span class="line">            <span class="keyword">for</span> _ <span class="keyword">in</span> <span class="built_in">range</span>(gail_epoch):</span><br><span class="line">                discr.update(gail_train_loader, rollouts,</span><br><span 
class="line">                             utils.get_vec_normalize(envs)._obfilt)</span><br><span class="line"></span><br><span class="line">            <span class="keyword">for</span> step <span class="keyword">in</span> <span class="built_in">range</span>(args.num_steps):</span><br><span class="line">                rollouts.rewards[step] = discr.predict_reward(</span><br><span class="line">                    rollouts.obs[step], rollouts.actions[step], args.gamma,</span><br><span class="line">                    rollouts.masks[step])</span><br><span class="line"></span><br><span class="line">        rollouts.compute_returns(next_value, args.use_gae, args.gamma,</span><br><span class="line">                                 args.gae_lambda, args.use_proper_time_limits)</span><br><span class="line"></span><br><span class="line">        value_loss, action_loss, dist_entropy = agent.update(rollouts)</span><br><span class="line"></span><br><span class="line">        rollouts.after_update()</span><br><span class="line"></span><br><span class="line">        <span class="comment"># save every save_interval updates, and after the final update</span></span><br><span class="line">        <span class="keyword">if</span> (j % args.save_interval == <span class="number">0</span></span><br><span class="line">                <span class="keyword">or</span> j == num_updates - <span class="number">1</span>) <span class="keyword">and</span> args.save_dir != <span class="string">&quot;&quot;</span>:</span><br><span class="line">            save_path = os.path.join(args.save_dir, args.algo)</span><br><span class="line">            <span class="keyword">try</span>:</span><br><span class="line">                os.makedirs(save_path)</span><br><span class="line">            <span class="keyword">except</span> OSError:</span><br><span class="line">                <span class="keyword">pass</span></span><br><span class="line"></span><br><span class="line">            
torch.save([</span><br><span class="line">                actor_critic,</span><br><span class="line">                <span class="built_in">getattr</span>(utils.get_vec_normalize(envs), <span class="string">&#x27;obs_rms&#x27;</span>, <span class="literal">None</span>)</span><br><span class="line">            ], os.path.join(save_path, args.env_name + <span class="string">&quot;.pt&quot;</span>))</span><br><span class="line"></span><br><span class="line">        <span class="keyword">if</span> j % args.log_interval == <span class="number">0</span> <span class="keyword">and</span> <span class="built_in">len</span>(episode_rewards) &gt; <span class="number">1</span>:</span><br><span class="line">            total_num_steps = (j + <span class="number">1</span>) * args.num_processes * args.num_steps</span><br><span class="line">            end = time.time()</span><br><span class="line">            <span class="built_in">print</span>(</span><br><span class="line">                <span class="string">&quot;Updates &#123;&#125;, num timesteps &#123;&#125;, FPS &#123;&#125; \n Last &#123;&#125; training episodes: mean/median reward &#123;:.1f&#125;/&#123;:.1f&#125;, min/max reward &#123;:.1f&#125;/&#123;:.1f&#125;, dist entropy &#123;:.3f&#125;, value loss &#123;:.3f&#125;, action loss &#123;:.3f&#125;\n&quot;</span></span><br><span class="line">                .<span class="built_in">format</span>(j, total_num_steps,</span><br><span class="line">                        <span class="built_in">int</span>(total_num_steps / (end - start)),</span><br><span class="line">                        <span class="built_in">len</span>(episode_rewards), np.mean(episode_rewards),</span><br><span class="line">                        np.median(episode_rewards), np.<span class="built_in">min</span>(episode_rewards),</span><br><span class="line">                        np.<span class="built_in">max</span>(episode_rewards), dist_entropy, value_loss,</span><br><span class="line">                        action_loss))</span><br><span class="line"></span><br><span class="line">        
<span class="keyword">if</span> (args.eval_interval <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span> <span class="keyword">and</span> <span class="built_in">len</span>(episode_rewards) &gt; <span class="number">1</span></span><br><span class="line">                <span class="keyword">and</span> j % args.eval_interval == <span class="number">0</span>):</span><br><span class="line">            obs_rms = utils.get_vec_normalize(envs).obs_rms</span><br><span class="line">            evaluate(actor_critic, obs_rms, args.env_name, args.seed,</span><br><span class="line">                     args.num_processes, eval_log_dir, device)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">if</span> __name__ == <span class="string">&quot;__main__&quot;</span>:</span><br><span class="line">    main()</span><br></pre></td></tr></table></figure><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span 
class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span 
class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br><span class="line">126</span><br><span class="line">127</span><br><span class="line">128</span><br><span class="line">129</span><br><span class="line">130</span><br><span class="line">131</span><br><span class="line">132</span><br><span class="line">133</span><br><span class="line">134</span><br><span class="line">135</span><br><span class="line">136</span><br><span class="line">137</span><br><span class="line">138</span><br><span class="line">139</span><br><span class="line">140</span><br><span class="line">141</span><br><span class="line">142</span><br><span class="line">143</span><br><span class="line">144</span><br><span class="line">145</span><br><span class="line">146</span><br><span class="line">147</span><br><span class="line">148</span><br><span class="line">149</span><br><span class="line">150</span><br><span class="line">151</span><br><span class="line">152</span><br><span 
class="line">153</span><br><span class="line">154</span><br><span class="line">155</span><br><span class="line">156</span><br><span class="line">157</span><br><span class="line">158</span><br><span class="line">159</span><br><span class="line">160</span><br><span class="line">161</span><br><span class="line">162</span><br><span class="line">163</span><br><span class="line">164</span><br><span class="line">165</span><br><span class="line">166</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torch.nn <span class="keyword">as</span> nn</span><br><span class="line"><span class="keyword">import</span> torch.nn.functional <span class="keyword">as</span> F</span><br><span class="line"><span class="keyword">import</span> torch.utils.data</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> autograd</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> stable_baselines3.common.running_mean_std <span class="keyword">import</span> RunningMeanStd</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">Discriminator</span>(nn.Module):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, input_dim, hidden_dim, device</span>):</span><br><span class="line">        <span class="built_in">super</span>(Discriminator, <span class="variable language_">self</span>).__init__()</span><br><span class="line"></span><br><span class="line">        <span class="variable language_">self</span>.device = device</span><br><span class="line"></span><br><span 
class="line">        <span class="variable language_">self</span>.trunk = nn.Sequential(</span><br><span class="line">            nn.Linear(input_dim, hidden_dim), nn.Tanh(),</span><br><span class="line">            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),</span><br><span class="line">            nn.Linear(hidden_dim, <span class="number">1</span>)).to(device)</span><br><span class="line"></span><br><span class="line">        <span class="variable language_">self</span>.trunk.train()</span><br><span class="line"></span><br><span class="line">        <span class="variable language_">self</span>.optimizer = torch.optim.Adam(<span class="variable language_">self</span>.trunk.parameters())</span><br><span class="line"></span><br><span class="line">        <span class="variable language_">self</span>.returns = <span class="literal">None</span></span><br><span class="line">        <span class="variable language_">self</span>.ret_rms = RunningMeanStd(shape=())</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">compute_grad_pen</span>(<span class="params">self,</span></span><br><span class="line"><span class="params">                         expert_state,</span></span><br><span class="line"><span class="params">                         expert_action,</span></span><br><span class="line"><span class="params">                         policy_state,</span></span><br><span class="line"><span class="params">                         policy_action,</span></span><br><span class="line"><span class="params">                         lambda_=<span class="number">10</span></span>):</span><br><span class="line">        alpha = torch.rand(expert_state.size(<span class="number">0</span>), <span class="number">1</span>)</span><br><span class="line">        expert_data = torch.cat([expert_state, expert_action], dim=<span class="number">1</span>)</span><br><span class="line">        policy_data = 
torch.cat([policy_state, policy_action], dim=<span class="number">1</span>)</span><br><span class="line"></span><br><span class="line">        alpha = alpha.expand_as(expert_data).to(expert_data.device)</span><br><span class="line"></span><br><span class="line">        mixup_data = alpha * expert_data + (<span class="number">1</span> - alpha) * policy_data</span><br><span class="line">        mixup_data.requires_grad = <span class="literal">True</span></span><br><span class="line"></span><br><span class="line">        disc = <span class="variable language_">self</span>.trunk(mixup_data)</span><br><span class="line">        ones = torch.ones(disc.size()).to(disc.device)</span><br><span class="line">        grad = autograd.grad(</span><br><span class="line">            outputs=disc,</span><br><span class="line">            inputs=mixup_data,</span><br><span class="line">            grad_outputs=ones,</span><br><span class="line">            create_graph=<span class="literal">True</span>,</span><br><span class="line">            retain_graph=<span class="literal">True</span>,</span><br><span class="line">            only_inputs=<span class="literal">True</span>)[<span class="number">0</span>]</span><br><span class="line"></span><br><span class="line">        grad_pen = lambda_ * (grad.norm(<span class="number">2</span>, dim=<span class="number">1</span>) - <span class="number">1</span>).<span class="built_in">pow</span>(<span class="number">2</span>).mean()</span><br><span class="line">        <span class="keyword">return</span> grad_pen</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">update</span>(<span class="params">self, expert_loader, rollouts, obsfilt=<span class="literal">None</span></span>):</span><br><span class="line">        <span class="variable language_">self</span>.train()</span><br><span class="line"></span><br><span class="line">        policy_data_generator = 
rollouts.feed_forward_generator(</span><br><span class="line">            <span class="literal">None</span>, mini_batch_size=expert_loader.batch_size)</span><br><span class="line"></span><br><span class="line">        loss = <span class="number">0</span></span><br><span class="line">        n = <span class="number">0</span></span><br><span class="line">        <span class="keyword">for</span> expert_batch, policy_batch <span class="keyword">in</span> <span class="built_in">zip</span>(expert_loader,</span><br><span class="line">                                              policy_data_generator):</span><br><span class="line">            policy_state, policy_action = policy_batch[<span class="number">0</span>], policy_batch[<span class="number">2</span>]</span><br><span class="line">            policy_d = <span class="variable language_">self</span>.trunk(</span><br><span class="line">                torch.cat([policy_state, policy_action], dim=<span class="number">1</span>))</span><br><span class="line"></span><br><span class="line">            expert_state, expert_action = expert_batch</span><br><span class="line">            expert_state = obsfilt(expert_state.numpy(), update=<span class="literal">False</span>)</span><br><span class="line">            expert_state = torch.FloatTensor(expert_state).to(<span class="variable language_">self</span>.device)</span><br><span class="line">            expert_action = expert_action.to(<span class="variable language_">self</span>.device)</span><br><span class="line">            expert_d = <span class="variable language_">self</span>.trunk(</span><br><span class="line">                torch.cat([expert_state, expert_action], dim=<span class="number">1</span>))</span><br><span class="line"></span><br><span class="line">            expert_loss = F.binary_cross_entropy_with_logits(</span><br><span class="line">                expert_d,</span><br><span class="line">                torch.ones(expert_d.size()).to(<span 
class="variable language_">self</span>.device))</span><br><span class="line">            policy_loss = F.binary_cross_entropy_with_logits(</span><br><span class="line">                policy_d,</span><br><span class="line">                torch.zeros(policy_d.size()).to(<span class="variable language_">self</span>.device))</span><br><span class="line"></span><br><span class="line">            gail_loss = expert_loss + policy_loss</span><br><span class="line">            grad_pen = <span class="variable language_">self</span>.compute_grad_pen(expert_state, expert_action,</span><br><span class="line">                                             policy_state, policy_action)</span><br><span class="line"></span><br><span class="line">            loss += (gail_loss + grad_pen).item()</span><br><span class="line">            n += <span class="number">1</span></span><br><span class="line"></span><br><span class="line">            <span class="variable language_">self</span>.optimizer.zero_grad()</span><br><span class="line">            (gail_loss + grad_pen).backward()</span><br><span class="line">            <span class="variable language_">self</span>.optimizer.step()</span><br><span class="line">        <span class="keyword">return</span> loss / n</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">predict_reward</span>(<span class="params">self, state, action, gamma, masks, update_rms=<span class="literal">True</span></span>):</span><br><span class="line">        <span class="keyword">with</span> torch.no_grad():</span><br><span class="line">            <span class="variable language_">self</span>.<span class="built_in">eval</span>()</span><br><span class="line">            d = <span class="variable language_">self</span>.trunk(torch.cat([state, action], dim=<span class="number">1</span>))</span><br><span class="line">            s = torch.sigmoid(d)</span><br><span class="line">            
reward = s.log() - (<span class="number">1</span> - s).log()</span><br><span class="line">            <span class="keyword">if</span> <span class="variable language_">self</span>.returns <span class="keyword">is</span> <span class="literal">None</span>:</span><br><span class="line">                <span class="variable language_">self</span>.returns = reward.clone()</span><br><span class="line"></span><br><span class="line">            <span class="keyword">if</span> update_rms:</span><br><span class="line">                <span class="variable language_">self</span>.returns = <span class="variable language_">self</span>.returns * masks * gamma + reward</span><br><span class="line">                <span class="variable language_">self</span>.ret_rms.update(<span class="variable language_">self</span>.returns.cpu().numpy())</span><br><span class="line"></span><br><span class="line">            <span class="keyword">return</span> reward / np.sqrt(<span class="variable language_">self</span>.ret_rms.var[<span class="number">0</span>] + <span class="number">1e-8</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">ExpertDataset</span>(torch.utils.data.Dataset):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, file_name, num_trajectories=<span class="number">4</span>, subsample_frequency=<span class="number">20</span></span>):</span><br><span class="line">        all_trajectories = torch.load(file_name)</span><br><span class="line">        </span><br><span class="line">        perm = torch.randperm(all_trajectories[<span class="string">&#x27;states&#x27;</span>].size(<span class="number">0</span>))</span><br><span class="line">        idx = perm[:num_trajectories]</span><br><span class="line"></span><br><span class="line">        <span class="variable 
language_">self</span>.trajectories = &#123;&#125;</span><br><span class="line">        </span><br><span class="line">        <span class="comment"># See https://github.com/pytorch/pytorch/issues/14886</span></span><br><span class="line">        <span class="comment"># .long() for fixing bug in torch v0.4.1</span></span><br><span class="line">        start_idx = torch.randint(</span><br><span class="line">            <span class="number">0</span>, subsample_frequency, size=(num_trajectories, )).long()</span><br><span class="line"></span><br><span class="line">        <span class="keyword">for</span> k, v <span class="keyword">in</span> all_trajectories.items():</span><br><span class="line">            data = v[idx]</span><br><span class="line"></span><br><span class="line">            <span class="keyword">if</span> k != <span class="string">&#x27;lengths&#x27;</span>:</span><br><span class="line">                samples = []</span><br><span class="line">                <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(num_trajectories):</span><br><span class="line">                    samples.append(data[i, start_idx[i]::subsample_frequency])</span><br><span class="line">                <span class="variable language_">self</span>.trajectories[k] = torch.stack(samples)</span><br><span class="line">            <span class="keyword">else</span>:</span><br><span class="line">                <span class="variable language_">self</span>.trajectories[k] = data // subsample_frequency</span><br><span class="line"></span><br><span class="line">        <span class="variable language_">self</span>.i2traj_idx = &#123;&#125;</span><br><span class="line">        <span class="variable language_">self</span>.i2i = &#123;&#125;</span><br><span class="line">        </span><br><span class="line">        <span class="variable language_">self</span>.length = <span class="variable language_">self</span>.trajectories[<span 
class="string">&#x27;lengths&#x27;</span>].<span class="built_in">sum</span>().item()</span><br><span class="line"></span><br><span class="line">        traj_idx = <span class="number">0</span></span><br><span class="line">        i = <span class="number">0</span></span><br><span class="line"></span><br><span class="line">        <span class="variable language_">self</span>.get_idx = []</span><br><span class="line">        </span><br><span class="line">        <span class="keyword">for</span> j <span class="keyword">in</span> <span class="built_in">range</span>(<span class="variable language_">self</span>.length):</span><br><span class="line">            </span><br><span class="line">            <span class="keyword">while</span> <span class="variable language_">self</span>.trajectories[<span class="string">&#x27;lengths&#x27;</span>][traj_idx].item() &lt;= i:</span><br><span class="line">                i -= <span class="variable language_">self</span>.trajectories[<span class="string">&#x27;lengths&#x27;</span>][traj_idx].item()</span><br><span class="line">                traj_idx += <span class="number">1</span></span><br><span class="line"></span><br><span class="line">            <span class="variable language_">self</span>.get_idx.append((traj_idx, i))</span><br><span class="line"></span><br><span class="line">            i += <span class="number">1</span></span><br><span class="line">            </span><br><span class="line">            </span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__len__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="keyword">return</span> <span class="variable language_">self</span>.length</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__getitem__</span>(<span class="params">self, i</span>):</span><br><span class="line">        traj_idx, i = <span class="variable 
language_">self</span>.get_idx[i]</span><br><span class="line"></span><br><span class="line">        <span class="keyword">return</span> <span class="variable language_">self</span>.trajectories[<span class="string">&#x27;states&#x27;</span>][traj_idx][i], <span class="variable language_">self</span>.trajectories[</span><br><span class="line">            <span class="string">&#x27;actions&#x27;</span>][traj_idx][i]</span><br></pre></td></tr></table></figure><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span 
class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> gym</span><br><span class="line"><span class="keyword">import</span> seals  <span class="comment"># registers the seals/CartPole-v0 environment with gym</span></span><br><span class="line"><span class="keyword">from</span> stable_baselines3 <span class="keyword">import</span> PPO</span><br><span class="line"><span class="keyword">from</span> stable_baselines3.common.evaluation <span class="keyword">import</span> evaluate_policy</span><br><span class="line"><span class="keyword">from</span> stable_baselines3.ppo <span class="keyword">import</span> MlpPolicy</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> imitation.algorithms.adversarial.gail <span class="keyword">import</span> GAIL</span><br><span class="line"><span class="keyword">from</span> imitation.data <span class="keyword">import</span> rollout</span><br><span class="line"><span class="keyword">from</span> imitation.data.wrappers <span class="keyword">import</span> RolloutInfoWrapper</span><br><span class="line"><span class="keyword">from</span> imitation.rewards.reward_nets <span class="keyword">import</span> BasicRewardNet</span><br><span class="line"><span class="keyword">from</span> imitation.util.networks <span class="keyword">import</span> RunningNorm</span><br><span class="line"><span class="keyword">from</span> imitation.util.util <span class="keyword">import</span> make_vec_env</span><br><span class="line"></span><br><span class="line">rng = np.random.default_rng(<span class="number">0</span>)</span><br><span class="line"></span><br><span class="line">env = gym.make(<span
class="string">&quot;seals/CartPole-v0&quot;</span>)</span><br><span class="line">expert = PPO(policy=MlpPolicy, env=env, n_steps=<span class="number">64</span>)</span><br><span class="line">expert.learn(<span class="number">1000</span>)  <span class="comment"># only 1000 steps: fine for a quick demo, far too few for a proficient expert</span></span><br><span class="line"></span><br><span class="line">rollouts = rollout.rollout(</span><br><span class="line">    expert,</span><br><span class="line">    make_vec_env(</span><br><span class="line">        <span class="string">&quot;seals/CartPole-v0&quot;</span>,</span><br><span class="line">        n_envs=<span class="number">5</span>,</span><br><span class="line">        post_wrappers=[<span class="keyword">lambda</span> env, _: RolloutInfoWrapper(env)],</span><br><span class="line">        rng=rng,</span><br><span class="line">    ),</span><br><span class="line">    rollout.make_sample_until(min_timesteps=<span class="literal">None</span>, min_episodes=<span class="number">60</span>),</span><br><span class="line">    rng=rng,</span><br><span class="line">)</span><br><span class="line"></span><br><span class="line">venv = make_vec_env(<span class="string">&quot;seals/CartPole-v0&quot;</span>, n_envs=<span class="number">8</span>, rng=rng)</span><br><span class="line">learner = PPO(env=venv, policy=MlpPolicy)</span><br><span class="line">reward_net = BasicRewardNet(</span><br><span class="line">    venv.observation_space,</span><br><span class="line">    venv.action_space,</span><br><span class="line">    normalize_input_layer=RunningNorm,</span><br><span class="line">)</span><br><span class="line">gail_trainer = GAIL(</span><br><span class="line">    demonstrations=rollouts,</span><br><span class="line">    demo_batch_size=<span class="number">1024</span>,</span><br><span class="line">    gen_replay_buffer_capacity=<span class="number">2048</span>,</span><br><span class="line">    n_disc_updates_per_round=<span class="number">4</span>,</span><br><span class="line">    venv=venv,</span><br><span class="line">    
gen_algo=learner,</span><br><span class="line">    reward_net=reward_net,</span><br><span class="line">)</span><br><span class="line"></span><br><span class="line">gail_trainer.train(<span class="number">20000</span>)</span><br><span class="line">rewards, _ = evaluate_policy(learner, venv, <span class="number">100</span>, return_episode_rewards=<span class="literal">True</span>)</span><br><span class="line"><span class="built_in">print</span>(<span class="string">&quot;Rewards:&quot;</span>, rewards)</span><br></pre></td></tr></table></figure><p>TensorFlow</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span 
class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span 
class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br><span class="line">126</span><br><span class="line">127</span><br><span class="line">128</span><br><span class="line">129</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> argparse</span><br><span class="line"><span class="keyword">import</span> gym</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> tensorflow <span class="keyword">as</span> tf</span><br><span class="line"><span class="keyword">from</span> network_models.policy_net <span class="keyword">import</span> Policy_net</span><br><span class="line"><span class="keyword">from</span> network_models.discriminator <span class="keyword">import</span> Discriminator</span><br><span class="line"><span class="keyword">from</span> algo.ppo <span class="keyword">import</span> PPOTrain</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">argparser</span>():</span><br><span class="line">    parser = argparse.ArgumentParser()</span><br><span class="line">    parser.add_argument(<span class="string">&#x27;--logdir&#x27;</span>, <span class="built_in">help</span>=<span class="string">&#x27;log directory&#x27;</span>, 
default=<span class="string">&#x27;log/train/gail&#x27;</span>)</span><br><span class="line">    parser.add_argument(<span class="string">&#x27;--savedir&#x27;</span>, <span class="built_in">help</span>=<span class="string">&#x27;save directory&#x27;</span>, default=<span class="string">&#x27;trained_models/gail&#x27;</span>)</span><br><span class="line">    parser.add_argument(<span class="string">&#x27;--gamma&#x27;</span>, <span class="built_in">type</span>=<span class="built_in">float</span>, default=<span class="number">0.95</span>)</span><br><span class="line">    parser.add_argument(<span class="string">&#x27;--iteration&#x27;</span>, <span class="built_in">type</span>=<span class="built_in">int</span>, default=<span class="built_in">int</span>(<span class="number">1e4</span>))</span><br><span class="line">    <span class="keyword">return</span> parser.parse_args()</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">main</span>(<span class="params">args</span>):</span><br><span class="line">    env = gym.make(<span class="string">&#x27;CartPole-v0&#x27;</span>)</span><br><span class="line">    env.seed(<span class="number">0</span>)</span><br><span class="line">    ob_space = env.observation_space</span><br><span class="line">    Policy = Policy_net(<span class="string">&#x27;policy&#x27;</span>, env)</span><br><span class="line">    Old_Policy = Policy_net(<span class="string">&#x27;old_policy&#x27;</span>, env)</span><br><span class="line">    PPO = PPOTrain(Policy, Old_Policy, gamma=args.gamma)</span><br><span class="line">    D = Discriminator(env)</span><br><span class="line"></span><br><span class="line">    <span class="comment"># Load the expert&#x27;s observations and actions</span></span><br><span class="line">    expert_observations = np.genfromtxt(<span class="string">&#x27;trajectory/observations.csv&#x27;</span>)</span><br><span class="line">    expert_actions = np.genfromtxt(<span class="string">&#x27;trajectory/actions.csv&#x27;</span>, dtype=np.int32)</span><br><span class="line"></span><br><span class="line">    saver 
= tf.train.Saver()</span><br><span class="line"></span><br><span class="line">    <span class="keyword">with</span> tf.Session() <span class="keyword">as</span> sess:</span><br><span class="line">        writer = tf.summary.FileWriter(args.logdir, sess.graph)</span><br><span class="line">        sess.run(tf.global_variables_initializer())</span><br><span class="line"></span><br><span class="line">        obs = env.reset()</span><br><span class="line">        success_num = <span class="number">0</span></span><br><span class="line"></span><br><span class="line">        <span class="keyword">for</span> iteration <span class="keyword">in</span> <span class="built_in">range</span>(args.iteration):</span><br><span class="line">            observations = []</span><br><span class="line">            actions = []</span><br><span class="line">            rewards = []</span><br><span class="line">            v_preds = []</span><br><span class="line">            run_policy_steps = <span class="number">0</span></span><br><span class="line"></span><br><span class="line">            <span class="keyword">while</span> <span class="literal">True</span>:</span><br><span class="line">                run_policy_steps += <span class="number">1</span></span><br><span class="line">                obs = np.stack([obs]).astype(dtype=np.float32)</span><br><span class="line">                act, v_pred = Policy.act(obs=obs, stochastic=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">                act = act.item()  <span class="comment"># .item() replaces np.asscalar (deprecated, removed in NumPy 1.23)</span></span><br><span class="line">                v_pred = v_pred.item()</span><br><span class="line"></span><br><span class="line">                next_obs, reward, done, info = env.step(act)</span><br><span class="line"></span><br><span class="line">                observations.append(obs)</span><br><span class="line">                actions.append(act)</span><br><span class="line">                
rewards.append(reward)</span><br><span class="line">                v_preds.append(v_pred)</span><br><span class="line"></span><br><span class="line">                <span class="keyword">if</span> done:</span><br><span class="line">                    next_obs = np.stack([next_obs]).astype(dtype=np.float32)  <span class="comment"># prepare to feed placeholder Policy.obs</span></span><br><span class="line">                    _, v_pred = Policy.act(obs=next_obs, stochastic=<span class="literal">True</span>)</span><br><span class="line">                    v_preds_next = v_preds[<span class="number">1</span>:] + [v_pred.item()]</span><br><span class="line">                    obs = env.reset()</span><br><span class="line">                    <span class="keyword">break</span></span><br><span class="line">                <span class="keyword">else</span>:</span><br><span class="line">                    obs = next_obs</span><br><span class="line"></span><br><span class="line">            writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=<span class="string">&#x27;episode_length&#x27;</span>, simple_value=run_policy_steps)]),</span><br><span class="line">                               iteration)</span><br><span class="line">            writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=<span class="string">&#x27;episode_reward&#x27;</span>, simple_value=<span class="built_in">sum</span>(rewards))]),</span><br><span class="line">                               iteration)</span><br><span class="line"></span><br><span class="line">            <span class="keyword">if</span> <span class="built_in">sum</span>(rewards) &gt;= <span class="number">195</span>:</span><br><span class="line">                success_num += <span class="number">1</span></span><br><span class="line">                <span class="keyword">if</span> success_num &gt;= <span class="number">100</span>:</span><br><span class="line">                    saver.save(sess, args.savedir + 
<span class="string">&#x27;/model.ckpt&#x27;</span>)</span><br><span class="line">                    <span class="built_in">print</span>(<span class="string">&#x27;Clear!! Model saved.&#x27;</span>)</span><br><span class="line">                    <span class="keyword">break</span></span><br><span class="line">            <span class="keyword">else</span>:</span><br><span class="line">                success_num = <span class="number">0</span></span><br><span class="line"></span><br><span class="line">            observations = np.reshape(observations,newshape=[-<span class="number">1</span>] + <span class="built_in">list</span>(ob_space.shape))</span><br><span class="line">            actions = np.array(actions).astype(dtype = np.int32)</span><br><span class="line"></span><br><span class="line">            <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">2</span>):</span><br><span class="line">                D.train(expert_s = expert_observations,</span><br><span class="line">                        expert_a = expert_actions,</span><br><span class="line">                        agent_s = observations,</span><br><span class="line">                        agent_a = actions)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">            d_rewards = D.get_rewards(agent_s=observations,agent_a = actions)</span><br><span class="line">            d_rewards = np.reshape(d_rewards,newshape=[-<span class="number">1</span>]).astype(dtype=np.float32)</span><br><span class="line"></span><br><span class="line">            gaes = PPO.get_gaes(rewards=d_rewards, v_preds=v_preds, v_preds_next=v_preds_next)</span><br><span class="line">            gaes = np.array(gaes).astype(dtype=np.float32)</span><br><span class="line">            <span class="comment"># gaes = (gaes - gaes.mean()) / gaes.std()</span></span><br><span class="line">            v_preds_next = 
np.array(v_preds_next).astype(dtype=np.float32)</span><br><span class="line"></span><br><span class="line">            <span class="comment"># train policy</span></span><br><span class="line">            inp = [observations, actions, gaes, d_rewards, v_preds_next]</span><br><span class="line">            PPO.assign_policy_parameters()</span><br><span class="line">            <span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">6</span>):</span><br><span class="line">                sample_indices = np.random.randint(low=<span class="number">0</span>, high=observations.shape[<span class="number">0</span>],</span><br><span class="line">                                                   size=<span class="number">32</span>)  <span class="comment"># indices are in [low, high)</span></span><br><span class="line">                sampled_inp = [np.take(a=a, indices=sample_indices, axis=<span class="number">0</span>) <span class="keyword">for</span> a <span class="keyword">in</span> inp]  <span class="comment"># sample training data</span></span><br><span class="line">                PPO.train(obs=sampled_inp[<span class="number">0</span>],</span><br><span class="line">                          actions=sampled_inp[<span class="number">1</span>],</span><br><span class="line">                          gaes=sampled_inp[<span class="number">2</span>],</span><br><span class="line">                          rewards=sampled_inp[<span class="number">3</span>],</span><br><span class="line">                          v_preds_next=sampled_inp[<span class="number">4</span>])</span><br><span class="line"></span><br><span class="line">            summary = PPO.get_summary(obs=inp[<span class="number">0</span>],</span><br><span class="line">                                      actions=inp[<span class="number">1</span>],</span><br><span class="line">                                      gaes=inp[<span 
class="number">2</span>],</span><br><span class="line">                                      rewards=inp[<span class="number">3</span>],</span><br><span class="line">                                      v_preds_next=inp[<span class="number">4</span>])</span><br><span class="line"></span><br><span class="line">            writer.add_summary(summary, iteration)</span><br><span class="line">        writer.close()</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">if</span> __name__ == <span class="string">&#x27;__main__&#x27;</span>:</span><br><span class="line">    args = argparser()</span><br><span class="line">    main(args)</span><br><span class="line"></span><br></pre></td></tr></table></figure><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">traj = Trajectory(observations, actions, infos=None, terminal=True)</span><br></pre></td></tr></table></figure><p><strong>–pedestrians</strong></p>]]></content>
    
    <summary type="html">
    
      &lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; a2c_ppo_acktr.algo &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; gail&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;



&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;8&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;9&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;10&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;11&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;12&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;13&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;14&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;15&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;16&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;17&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;18&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;19&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;20&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;21&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;22&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;23&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;24&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;25&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;26&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;27&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;28&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;29&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;30&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;31&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;32&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;33&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;34&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;35&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;36&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;37&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;38&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;39&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;40&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;41&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;42&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;43&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;44&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;45&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;46&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;47&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;48&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;49&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;50&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;51&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;52&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;53&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;54&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;55&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;56&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;57&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;58&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;59&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;60&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;61&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;62&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;63&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;64&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;65&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;66&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;67&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;68&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;69&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;70&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;71&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;72&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;73&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;74&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;75&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;76&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;77&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;78&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;79&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;80&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;81&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;82&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;83&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;84&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;85&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;86&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;87&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;88&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;89&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;90&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;91&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;92&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;93&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;94&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;95&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;96&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;97&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;98&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;99&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;100&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;101&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;102&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;103&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;104&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;105&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;106&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;107&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;108&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;109&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;110&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;111&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;112&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;113&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;114&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;115&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;116&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;117&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;118&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;119&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;120&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;121&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;122&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;123&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;124&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;125&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;126&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;127&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;128&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;129&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;130&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;131&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;132&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;133&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;134&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;135&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;136&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;137&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;138&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;139&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;140&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;141&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;142&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;143&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;144&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;145&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;146&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;147&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;148&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;149&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;150&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;151&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;152&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;153&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;154&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;155&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;156&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;157&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;158&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;159&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;160&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;161&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;162&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;163&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;164&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;165&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;166&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;167&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;168&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;169&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;170&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;171&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;172&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;173&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;174&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;175&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;176&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;177&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;178&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;179&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;180&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;181&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;182&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;183&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;184&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;185&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;186&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;187&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;188&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;189&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;190&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;191&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;192&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;193&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;194&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;195&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;196&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;197&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;198&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; copy&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; glob&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; os&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; time&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; collections &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; deque&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; gym&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; numpy &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; np&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; torch&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; torch.nn &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; nn&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; torch.nn.functional &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; F&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; torch.optim &lt;span 
class=&quot;keyword&quot;&gt;as&lt;/span&gt; optim&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; a2c_ppo_acktr &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; algo, utils&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; a2c_ppo_acktr.algo &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; gail&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; a2c_ppo_acktr.arguments &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; get_args&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; a2c_ppo_acktr.envs &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; make_vec_envs&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; a2c_ppo_acktr.model &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; Policy&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; a2c_ppo_acktr.storage &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; RolloutStorage&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; evaluation &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; evaluate&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;main&lt;/span&gt;():&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    args = get_args()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    
torch.manual_seed(args.seed)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    torch.cuda.manual_seed_all(args.seed)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; args.cuda &lt;span class=&quot;keyword&quot;&gt;and&lt;/span&gt; torch.cuda.is_available() &lt;span class=&quot;keyword&quot;&gt;and&lt;/span&gt; args.cuda_deterministic:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        torch.backends.cudnn.benchmark = &lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        torch.backends.cudnn.deterministic = &lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    log_dir = os.path.expanduser(args.log_dir)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    eval_log_dir = log_dir + &lt;span class=&quot;string&quot;&gt;&amp;quot;_eval&amp;quot;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    utils.cleanup_log_dir(log_dir)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    utils.cleanup_log_dir(eval_log_dir)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    torch.set_num_threads(&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    device = torch.device(&lt;span class=&quot;string&quot;&gt;&amp;quot;cuda:0&amp;quot;&lt;/span&gt; &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; args.cuda &lt;span class=&quot;keyword&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;string&quot;&gt;&amp;quot;cpu&amp;quot;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    envs = 
make_vec_envs(args.env_name, args.seed, args.num_processes,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                         args.gamma, args.log_dir, device, &lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    actor_critic = Policy(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        envs.observation_space.shape,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        envs.action_space,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        base_kwargs=&amp;#123;&lt;span class=&quot;string&quot;&gt;&amp;#x27;recurrent&amp;#x27;&lt;/span&gt;: args.recurrent_policy&amp;#125;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    actor_critic.to(device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; args.algo == &lt;span class=&quot;string&quot;&gt;&amp;#x27;a2c&amp;#x27;&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        agent = algo.A2C_ACKTR(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            actor_critic,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            args.value_loss_coef,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            args.entropy_coef,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            lr=args.lr,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            eps=args.eps,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            alpha=args.alpha,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            max_grad_norm=args.max_grad_norm)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;elif&lt;/span&gt; args.algo == &lt;span 
class=&quot;string&quot;&gt;&amp;#x27;ppo&amp;#x27;&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        agent = algo.PPO(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            actor_critic,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            args.clip_param,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            args.ppo_epoch,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            args.num_mini_batch,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            args.value_loss_coef,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            args.entropy_coef,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            lr=args.lr,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            eps=args.eps,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            max_grad_norm=args.max_grad_norm)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;elif&lt;/span&gt; args.algo == &lt;span class=&quot;string&quot;&gt;&amp;#x27;acktr&amp;#x27;&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        agent = algo.A2C_ACKTR(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            actor_critic, args.value_loss_coef, args.entropy_coef, acktr=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; args.gail:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;len&lt;/span&gt;(envs.observation_space.shape) == &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        discr = gail.Discriminator(&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;            envs.observation_space.shape[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;] + envs.action_space.shape[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;], &lt;span class=&quot;number&quot;&gt;100&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        file_name = os.path.join(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            args.gail_experts_dir, &lt;span class=&quot;string&quot;&gt;&amp;quot;trajs_&amp;#123;&amp;#125;.pt&amp;quot;&lt;/span&gt;.&lt;span class=&quot;built_in&quot;&gt;format&lt;/span&gt;(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                args.env_name.split(&lt;span class=&quot;string&quot;&gt;&amp;#x27;-&amp;#x27;&lt;/span&gt;)[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;].lower()))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        expert_dataset = gail.ExpertDataset(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            file_name, num_trajectories=&lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;, subsample_frequency=&lt;span class=&quot;number&quot;&gt;20&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        drop_last = &lt;span class=&quot;built_in&quot;&gt;len&lt;/span&gt;(expert_dataset) &amp;gt; args.gail_batch_size&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        gail_train_loader = torch.utils.data.DataLoader(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            dataset=expert_dataset,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            batch_size=args.gail_batch_size,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            shuffle=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            
drop_last=drop_last)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    rollouts = RolloutStorage(args.num_steps, args.num_processes,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                              envs.observation_space.shape, envs.action_space,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                              actor_critic.recurrent_hidden_state_size)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    obs = envs.reset()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    rollouts.obs[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;].copy_(obs)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    rollouts.to(device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    episode_rewards = deque(maxlen=&lt;span class=&quot;number&quot;&gt;10&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    start = time.time()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    num_updates = &lt;span class=&quot;built_in&quot;&gt;int&lt;/span&gt;(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        args.num_env_steps) // args.num_steps // args.num_processes&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; j &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;range&lt;/span&gt;(num_updates):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; args.use_linear_lr_decay:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;comment&quot;&gt;# decrease learning 
rate linearly&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            utils.update_linear_schedule(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                agent.optimizer, j, num_updates,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                agent.optimizer.lr &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; args.algo == &lt;span class=&quot;string&quot;&gt;&amp;quot;acktr&amp;quot;&lt;/span&gt; &lt;span class=&quot;keyword&quot;&gt;else&lt;/span&gt; args.lr)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; step &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;range&lt;/span&gt;(args.num_steps):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;comment&quot;&gt;# Sample actions&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;with&lt;/span&gt; torch.no_grad():&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    rollouts.obs[step], rollouts.recurrent_hidden_states[step],&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    rollouts.masks[step])&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;comment&quot;&gt;# Observe reward and next obs&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            obs, reward, done, infos = envs.step(action)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span 
class=&quot;keyword&quot;&gt;for&lt;/span&gt; info &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; infos:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;string&quot;&gt;&amp;#x27;episode&amp;#x27;&lt;/span&gt; &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; info.keys():&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    episode_rewards.append(info[&lt;span class=&quot;string&quot;&gt;&amp;#x27;episode&amp;#x27;&lt;/span&gt;][&lt;span class=&quot;string&quot;&gt;&amp;#x27;r&amp;#x27;&lt;/span&gt;])&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;comment&quot;&gt;# If done then clean the history of observations.&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            masks = torch.FloatTensor(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                [[&lt;span class=&quot;number&quot;&gt;0.0&lt;/span&gt;] &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; done_ &lt;span class=&quot;keyword&quot;&gt;else&lt;/span&gt; [&lt;span class=&quot;number&quot;&gt;1.0&lt;/span&gt;] &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; done_ &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; done])&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            bad_masks = torch.FloatTensor(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                [[&lt;span class=&quot;number&quot;&gt;0.0&lt;/span&gt;] &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;string&quot;&gt;&amp;#x27;bad_transition&amp;#x27;&lt;/span&gt; &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; info.keys() &lt;span class=&quot;keyword&quot;&gt;else&lt;/span&gt; [&lt;span class=&quot;number&quot;&gt;1.0&lt;/span&gt;]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;      
           &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; info &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; infos])&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            rollouts.insert(obs, recurrent_hidden_states, action,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                            action_log_prob, value, reward, masks, bad_masks)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;with&lt;/span&gt; torch.no_grad():&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            next_value = actor_critic.get_value(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                rollouts.obs[-&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;], rollouts.recurrent_hidden_states[-&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;],&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                rollouts.masks[-&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;]).detach()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; args.gail:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; j &amp;gt;= &lt;span class=&quot;number&quot;&gt;10&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                envs.venv.&lt;span class=&quot;built_in&quot;&gt;eval&lt;/span&gt;()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            gail_epoch = args.gail_epoch&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; j &amp;lt; &lt;span class=&quot;number&quot;&gt;10&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;      
          gail_epoch = &lt;span class=&quot;number&quot;&gt;100&lt;/span&gt;  &lt;span class=&quot;comment&quot;&gt;# Warm up&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; _ &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;range&lt;/span&gt;(gail_epoch):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                discr.update(gail_train_loader, rollouts,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                             utils.get_vec_normalize(envs)._obfilt)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; step &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;range&lt;/span&gt;(args.num_steps):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                rollouts.rewards[step] = discr.predict_reward(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    rollouts.obs[step], rollouts.actions[step], args.gamma,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    rollouts.masks[step])&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        rollouts.compute_returns(next_value, args.use_gae, args.gamma,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                                 args.gae_lambda, args.use_proper_time_limits)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        value_loss, action_loss, dist_entropy = agent.update(rollouts)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        rollouts.after_update()&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;comment&quot;&gt;# save a checkpoint every args.save_interval updates and after the final update&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; (j % args.save_interval == &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;keyword&quot;&gt;or&lt;/span&gt; j == num_updates - &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;) &lt;span class=&quot;keyword&quot;&gt;and&lt;/span&gt; args.save_dir != &lt;span class=&quot;string&quot;&gt;&amp;quot;&amp;quot;&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            save_path = os.path.join(args.save_dir, args.algo)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;try&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                os.makedirs(save_path)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;except&lt;/span&gt; OSError:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;keyword&quot;&gt;pass&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            torch.save([&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                actor_critic,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;built_in&quot;&gt;getattr&lt;/span&gt;(utils.get_vec_normalize(envs), &lt;span class=&quot;string&quot;&gt;&amp;#x27;obs_rms&amp;#x27;&lt;/span&gt;, &lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            ], os.path.join(save_path, args.env_name + &lt;span 
class=&quot;string&quot;&gt;&amp;quot;.pt&amp;quot;&lt;/span&gt;))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; j % args.log_interval == &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt; &lt;span class=&quot;keyword&quot;&gt;and&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;len&lt;/span&gt;(episode_rewards) &amp;gt; &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            total_num_steps = (j + &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;) * args.num_processes * args.num_steps&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            end = time.time()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;built_in&quot;&gt;print&lt;/span&gt;(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;string&quot;&gt;&amp;quot;Updates &amp;#123;&amp;#125;, num timesteps &amp;#123;&amp;#125;, FPS &amp;#123;&amp;#125; &#92;n Last &amp;#123;&amp;#125; training episodes: mean/median reward &amp;#123;:.1f&amp;#125;/&amp;#123;:.1f&amp;#125;, min/max reward &amp;#123;:.1f&amp;#125;/&amp;#123;:.1f&amp;#125;&#92;n dist_entropy &amp;#123;:.3f&amp;#125;, value_loss &amp;#123;:.3f&amp;#125;, action_loss &amp;#123;:.3f&amp;#125;&#92;n&amp;quot;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                .&lt;span class=&quot;built_in&quot;&gt;format&lt;/span&gt;(j, total_num_steps,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                        &lt;span class=&quot;built_in&quot;&gt;int&lt;/span&gt;(total_num_steps / (end - start)),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                        &lt;span class=&quot;built_in&quot;&gt;len&lt;/span&gt;(episode_rewards), np.mean(episode_rewards),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                        np.median(episode_rewards), np.&lt;span 
class=&quot;built_in&quot;&gt;min&lt;/span&gt;(episode_rewards),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                        np.&lt;span class=&quot;built_in&quot;&gt;max&lt;/span&gt;(episode_rewards), dist_entropy, value_loss,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                        action_loss))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; (args.eval_interval &lt;span class=&quot;keyword&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;keyword&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt; &lt;span class=&quot;keyword&quot;&gt;and&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;len&lt;/span&gt;(episode_rewards) &amp;gt; &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;keyword&quot;&gt;and&lt;/span&gt; j % args.eval_interval == &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            obs_rms = utils.get_vec_normalize(envs).obs_rms&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            evaluate(actor_critic, obs_rms, args.env_name, args.seed,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                     args.num_processes, eval_log_dir, device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; __name__ == &lt;span class=&quot;string&quot;&gt;&amp;quot;__main__&amp;quot;&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    main()&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;



&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;8&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;9&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;10&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;11&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;12&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;13&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;14&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;15&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;16&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;17&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;18&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;19&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;20&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;21&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;22&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;23&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;24&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;25&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;26&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;27&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;28&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;29&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;30&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;31&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;32&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;33&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;34&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;35&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;36&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;37&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;38&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;39&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;40&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;41&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;42&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;43&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;44&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;45&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;46&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;47&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;48&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;49&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;50&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;51&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;52&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;53&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;54&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;55&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;56&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;57&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;58&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;59&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;60&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;61&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;62&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;63&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;64&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;65&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;66&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;67&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;68&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;69&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;70&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;71&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;72&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;73&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;74&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;75&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;76&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;77&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;78&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;79&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;80&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;81&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;82&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;83&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;84&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;85&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;86&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;87&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;88&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;89&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;90&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;91&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;92&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;93&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;94&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;95&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;96&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;97&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;98&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;99&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;100&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;101&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;102&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;103&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;104&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;105&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;106&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;107&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;108&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;109&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;110&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;111&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;112&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;113&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;114&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;115&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;116&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;117&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;118&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;119&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;120&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;121&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;122&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;123&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;124&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;125&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;126&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;127&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;128&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;129&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;130&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;131&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;132&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;133&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;134&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;135&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;136&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;137&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;138&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;139&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;140&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;141&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;142&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;143&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;144&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;145&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;146&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;147&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;148&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;149&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;150&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;151&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;152&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;153&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;154&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;155&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;156&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;157&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;158&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;159&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;160&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;161&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;162&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;163&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;164&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;165&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;166&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; h5py&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; numpy &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; np&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; torch&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; torch.nn &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; nn&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; torch.nn.functional &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; F&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; torch.utils.data&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; torch &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; autograd&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; stable_baselines3.common.running_mean_std &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; RunningMeanStd&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;title 
class_&quot;&gt;Discriminator&lt;/span&gt;(nn.Module):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;__init__&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self, input_dim, hidden_dim, device&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;built_in&quot;&gt;super&lt;/span&gt;(Discriminator, &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;).__init__()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.device = device&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trunk = nn.Sequential(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Linear(input_dim, hidden_dim), nn.Tanh(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Linear(hidden_dim, &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;)).to(device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trunk.train()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.optimizer = torch.optim.Adam(&lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trunk.parameters())&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span 
class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.returns = &lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.ret_rms = RunningMeanStd(shape=())&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;compute_grad_pen&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self,&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;params&quot;&gt;                         expert_state,&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;params&quot;&gt;                         expert_action,&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;params&quot;&gt;                         policy_state,&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;params&quot;&gt;                         policy_action,&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;params&quot;&gt;                         lambda_=&lt;span class=&quot;number&quot;&gt;10&lt;/span&gt;&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        alpha = torch.rand(expert_state.size(&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;), &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        expert_data = torch.cat([expert_state, expert_action], dim=&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        policy_data = torch.cat([policy_state, policy_action], dim=&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        alpha = alpha.expand_as(expert_data).to(expert_data.device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        mixup_data = alpha * expert_data + (&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt; - alpha) * policy_data&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        mixup_data.requires_grad = &lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        disc = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trunk(mixup_data)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        ones = torch.ones(disc.size()).to(disc.device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        grad = autograd.grad(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            outputs=disc,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            inputs=mixup_data,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            grad_outputs=ones,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            create_graph=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            retain_graph=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            only_inputs=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;)[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        grad_pen = lambda_ * (grad.norm(&lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;, dim=&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;) - &lt;span 
class=&quot;number&quot;&gt;1&lt;/span&gt;).&lt;span class=&quot;built_in&quot;&gt;pow&lt;/span&gt;(&lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;).mean()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; grad_pen&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;update&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self, expert_loader, rollouts, obsfilt=&lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt;&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.train()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        policy_data_generator = rollouts.feed_forward_generator(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt;, mini_batch_size=expert_loader.batch_size)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        loss = &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        n = &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; expert_batch, policy_batch &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;zip&lt;/span&gt;(expert_loader,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                                              policy_data_generator):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            policy_state, policy_action = policy_batch[&lt;span 
class=&quot;number&quot;&gt;0&lt;/span&gt;], policy_batch[&lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            policy_d = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trunk(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                torch.cat([policy_state, policy_action], dim=&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            expert_state, expert_action = expert_batch&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            expert_state = obsfilt(expert_state.numpy(), update=&lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            expert_state = torch.FloatTensor(expert_state).to(&lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            expert_action = expert_action.to(&lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            expert_d = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trunk(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                torch.cat([expert_state, expert_action], dim=&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            expert_loss = F.binary_cross_entropy_with_logits(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                expert_d,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                torch.ones(expert_d.size()).to(&lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.device))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            
policy_loss = F.binary_cross_entropy_with_logits(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                policy_d,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                torch.zeros(policy_d.size()).to(&lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.device))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            gail_loss = expert_loss + policy_loss&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            grad_pen = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.compute_grad_pen(expert_state, expert_action,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                                             policy_state, policy_action)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            loss += (gail_loss + grad_pen).item()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            n += &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.optimizer.zero_grad()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            (gail_loss + grad_pen).backward()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.optimizer.step()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; loss / n&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;predict_reward&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self, state, 
action, gamma, masks, update_rms=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;with&lt;/span&gt; torch.no_grad():&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.&lt;span class=&quot;built_in&quot;&gt;eval&lt;/span&gt;()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            d = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trunk(torch.cat([state, action], dim=&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;comment&quot;&gt;# log(D) - log(1 - D) with D = sigmoid(d); logsigmoid avoids inf when D saturates&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            reward = F.logsigmoid(d) - F.logsigmoid(-d)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.returns &lt;span class=&quot;keyword&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.returns = reward.clone()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; update_rms:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.returns = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.returns * masks * gamma + reward&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.ret_rms.update(&lt;span 
class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.returns.cpu().numpy())&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; reward / np.sqrt(&lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.ret_rms.var[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;] + &lt;span class=&quot;number&quot;&gt;1e-8&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;title class_&quot;&gt;ExpertDataset&lt;/span&gt;(torch.utils.data.Dataset):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;__init__&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self, file_name, num_trajectories=&lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;, subsample_frequency=&lt;span class=&quot;number&quot;&gt;20&lt;/span&gt;&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        all_trajectories = torch.load(file_name)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        perm = torch.randperm(all_trajectories[&lt;span class=&quot;string&quot;&gt;&amp;#x27;states&amp;#x27;&lt;/span&gt;].size(&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        idx = perm[:num_trajectories]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trajectories = &amp;#123;&amp;#125;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        
&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;comment&quot;&gt;# See https://github.com/pytorch/pytorch/issues/14886&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;comment&quot;&gt;# .long() for fixing bug in torch v0.4.1&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        start_idx = torch.randint(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;, subsample_frequency, size=(num_trajectories, )).long()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; k, v &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; all_trajectories.items():&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            data = v[idx]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; k != &lt;span class=&quot;string&quot;&gt;&amp;#x27;lengths&amp;#x27;&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                samples = []&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; i &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;range&lt;/span&gt;(num_trajectories):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    samples.append(data[i, start_idx[i]::subsample_frequency])&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trajectories[k] = torch.stack(samples)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span 
class=&quot;keyword&quot;&gt;else&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trajectories[k] = data // subsample_frequency&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.i2traj_idx = &amp;#123;&amp;#125;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.i2i = &amp;#123;&amp;#125;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.length = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trajectories[&lt;span class=&quot;string&quot;&gt;&amp;#x27;lengths&amp;#x27;&lt;/span&gt;].&lt;span class=&quot;built_in&quot;&gt;sum&lt;/span&gt;().item()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        traj_idx = &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        i = &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.get_idx = []&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; j &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;range&lt;/span&gt;(&lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.length):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;         
   &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;while&lt;/span&gt; &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trajectories[&lt;span class=&quot;string&quot;&gt;&amp;#x27;lengths&amp;#x27;&lt;/span&gt;][traj_idx].item() &amp;lt;= i:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                i -= &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trajectories[&lt;span class=&quot;string&quot;&gt;&amp;#x27;lengths&amp;#x27;&lt;/span&gt;][traj_idx].item()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                traj_idx += &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.get_idx.append((traj_idx, i))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            i += &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;__len__&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.length&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;__getitem__&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self, i&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        traj_idx, i = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.get_idx[i]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trajectories[&lt;span class=&quot;string&quot;&gt;&amp;#x27;states&amp;#x27;&lt;/span&gt;][traj_idx][i], &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.trajectories[&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;string&quot;&gt;&amp;#x27;actions&amp;#x27;&lt;/span&gt;][traj_idx][i]&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;
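&lt;p&gt;The constructor above flattens all subsampled trajectories into a single index space. &lt;code&gt;get_idx[j]&lt;/code&gt; stores the (trajectory index, step index) pair for flat sample &lt;code&gt;j&lt;/code&gt;, so &lt;code&gt;__getitem__&lt;/code&gt; needs only one list lookup. The minimal sketch below reproduces that mapping on plain Python lists so it can be checked on its own. It also adds an equivalent binary-search lookup over cumulative lengths. The helper names &lt;code&gt;build_index&lt;/code&gt; and &lt;code&gt;lookup&lt;/code&gt; and the example lengths are hypothetical and do not come from the original code.&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
```python
import bisect
from itertools import accumulate


def build_index(lengths):
    """Map every flat sample index j to a (trajectory index, step index) pair.

    Mirrors the while-loop in the dataset constructor above: walk through the
    flat range and move to the next trajectory whenever the step counter
    reaches the current trajectory's length (empty trajectories are skipped).
    """
    get_idx = []
    traj_idx, i = 0, 0
    for _ in range(sum(lengths)):
        while i >= lengths[traj_idx]:  # current trajectory exhausted
            i -= lengths[traj_idx]
            traj_idx += 1
        get_idx.append((traj_idx, i))
        i += 1
    return get_idx


def lookup(lengths, j):
    """Same mapping for a single j via binary search over cumulative lengths."""
    ends = list(accumulate(lengths))         # exclusive end offset of each trajectory
    traj_idx = bisect.bisect_right(ends, j)  # first trajectory whose end offset exceeds j
    start = ends[traj_idx] - lengths[traj_idx]
    return traj_idx, j - start


lengths = [3, 0, 2]  # hypothetical per-trajectory lengths after subsampling
table = build_index(lengths)
print(table)  # [(0, 0), (0, 1), (0, 2), (2, 0), (2, 1)]
assert all(lookup(lengths, j) == table[j] for j in range(sum(lengths)))
```
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;Both versions return the same pairs. The precomputed list stores one tuple per sample and makes each lookup O(1). The bisect version keeps only the cumulative lengths and pays O(log n) per lookup.&lt;/p&gt;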











&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;8&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;9&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;10&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;11&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;12&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;13&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;14&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;15&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;16&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;17&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;18&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;19&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;20&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;21&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;22&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;23&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;24&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;25&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;26&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;27&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;28&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;29&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;30&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;31&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;32&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;33&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;34&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;35&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;36&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;37&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;38&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;39&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;40&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;41&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;42&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;43&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;44&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;45&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;46&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;47&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;48&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;49&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;50&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;51&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;52&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;53&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; numpy &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; np&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; gym&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; seals  &lt;span class=&quot;comment&quot;&gt;# registers the seals environments (e.g. seals/CartPole-v0) with gym&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; stable_baselines3 &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; PPO&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; stable_baselines3.common.evaluation &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; evaluate_policy&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; stable_baselines3.common.vec_env &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; DummyVecEnv&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; stable_baselines3.ppo &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; MlpPolicy&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; imitation.algorithms.adversarial.gail &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; GAIL&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; imitation.data &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; rollout&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; imitation.data.wrappers &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; RolloutInfoWrapper&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; imitation.rewards.reward_nets &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; BasicRewardNet&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; imitation.util.networks &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; RunningNorm&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; imitation.util.util &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; make_vec_env&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;rng = np.random.default_rng(&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;env = gym.make(&lt;span class=&quot;string&quot;&gt;&amp;quot;seals/CartPole-v0&amp;quot;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;expert = PPO(policy=MlpPolicy, env=env, n_steps=&lt;span class=&quot;number&quot;&gt;64&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;expert.learn(&lt;span class=&quot;number&quot;&gt;1000&lt;/span&gt;)  &lt;span class=&quot;comment&quot;&gt;# short demo run; a proficient CartPole expert usually needs many more timesteps&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;rollouts = rollout.rollout(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    expert,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    make_vec_env(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;string&quot;&gt;&amp;quot;seals/CartPole-v0&amp;quot;&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        n_envs=&lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        post_wrappers=[&lt;span class=&quot;keyword&quot;&gt;lambda&lt;/span&gt; env, _: RolloutInfoWrapper(env)],&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        rng=rng,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    ),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    rollout.make_sample_until(min_timesteps=&lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt;, min_episodes=&lt;span class=&quot;number&quot;&gt;60&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    rng=rng,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;venv = make_vec_env(&lt;span class=&quot;string&quot;&gt;&amp;quot;seals/CartPole-v0&amp;quot;&lt;/span&gt;, n_envs=&lt;span class=&quot;number&quot;&gt;8&lt;/span&gt;, rng=rng)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;learner = PPO(env=venv, policy=MlpPolicy)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;reward_net = BasicRewardNet(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    venv.observation_space,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    venv.action_space,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    normalize_input_layer=RunningNorm,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;gail_trainer = GAIL(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    demonstrations=rollouts,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    demo_batch_size=&lt;span class=&quot;number&quot;&gt;1024&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    gen_replay_buffer_capacity=&lt;span class=&quot;number&quot;&gt;2048&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    n_disc_updates_per_round=&lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    venv=venv,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    gen_algo=learner,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    reward_net=reward_net,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;gail_trainer.train(&lt;span class=&quot;number&quot;&gt;20000&lt;/span&gt;)  &lt;span class=&quot;comment&quot;&gt;# total number of learner (generator) environment steps&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;rewards, _ = evaluate_policy(learner, venv, &lt;span class=&quot;number&quot;&gt;100&lt;/span&gt;, return_episode_rewards=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;built_in&quot;&gt;print&lt;/span&gt;(&lt;span 
class=&quot;string&quot;&gt;&amp;quot;Rewards:&amp;quot;&lt;/span&gt;, rewards)&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;
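&lt;p&gt;In GAIL the learner is not trained on the environment reward. Instead, a discriminator (here &lt;code&gt;reward_net&lt;/code&gt;) learns to separate expert (state, action) pairs from the learner's pairs, and its output is turned into a surrogate reward for PPO. In this script the environment reward is used to score the learner through &lt;code&gt;evaluate_policy&lt;/code&gt;. The minimal NumPy sketch below shows one common conversion. It assumes the discriminator outputs a logit whose sigmoid D(s, a) is the probability that the pair came from the expert, and it uses r = -log(1 - D(s, a)). This convention and the function names are illustrative assumptions. They are not necessarily what &lt;code&gt;imitation&lt;/code&gt; or the TensorFlow code below computes internally.&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
```python
import numpy as np


def gail_reward(expert_logits):
    """Surrogate GAIL reward r(s, a) = -log(1 - D(s, a)).

    Assumes D(s, a) = sigmoid(logit) is the discriminator's probability that
    (s, a) came from the expert. Because 1 - sigmoid(x) = sigmoid(-x), the
    reward equals log(1 + exp(x)), and np.logaddexp(0, x) evaluates that
    without overflow or log(0).
    """
    x = np.asarray(expert_logits, dtype=np.float64)
    return np.logaddexp(0.0, x)


def gail_reward_naive(expert_logits):
    """Direct formula for comparison; hits log(0) once sigmoid rounds to 1."""
    d = 1.0 / (1.0 + np.exp(-np.asarray(expert_logits, dtype=np.float64)))
    return -np.log(1.0 - d)


logits = np.array([-2.0, 0.0, 2.0])  # hypothetical discriminator outputs
print(gail_reward(logits))           # grows as the pair looks more expert-like
assert np.allclose(gail_reward(logits), gail_reward_naive(logits))
```
&lt;/code&gt;&lt;/pre&gt;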



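&lt;p&gt;TensorFlow version: a similar GAIL-with-PPO setup on CartPole, written by hand. It uses the TensorFlow 1.x graph API (&lt;code&gt;tf.Session&lt;/code&gt;, &lt;code&gt;tf.train.Saver&lt;/code&gt;) and the pre-0.26 gym API (&lt;code&gt;env.seed&lt;/code&gt;, four-value &lt;code&gt;env.step&lt;/code&gt;). It imports project-local modules (&lt;code&gt;network_models.policy_net&lt;/code&gt;, &lt;code&gt;network_models.discriminator&lt;/code&gt;, &lt;code&gt;algo.ppo&lt;/code&gt;) that are not listed here, and it reads the expert demonstrations from &lt;code&gt;trajectory/observations.csv&lt;/code&gt; and &lt;code&gt;trajectory/actions.csv&lt;/code&gt;. Each iteration collects one episode, trains the discriminator on expert pairs versus policy pairs, and then updates the policy with PPO using the discriminator rewards. Training stops, and the model is saved, once 100 consecutive episodes reach a return of at least 195.&lt;/p&gt;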
&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;8&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;9&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;10&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;11&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;12&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;13&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;14&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;15&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;16&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;17&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;18&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;19&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;20&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;21&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;22&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;23&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;24&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;25&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;26&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;27&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;28&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;29&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;30&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;31&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;32&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;33&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;34&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;35&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;36&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;37&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;38&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;39&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;40&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;41&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;42&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;43&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;44&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;45&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;46&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;47&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;48&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;49&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;50&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;51&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;52&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;53&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;54&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;55&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;56&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;57&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;58&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;59&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;60&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;61&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;62&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;63&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;64&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;65&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;66&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;67&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;68&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;69&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;70&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;71&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;72&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;73&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;74&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;75&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;76&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;77&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;78&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;79&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;80&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;81&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;82&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;83&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;84&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;85&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;86&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;87&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;88&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;89&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;90&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;91&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;92&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;93&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;94&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;95&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;96&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;97&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;98&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;99&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;100&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;101&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;102&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;103&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;104&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;105&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;106&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;107&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;108&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;109&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;110&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;111&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;112&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;113&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;114&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;115&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;116&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;117&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;118&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;119&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;120&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;121&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;122&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;123&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;124&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;125&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;126&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;127&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;128&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;129&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td 
class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; argparse&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; gym&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; numpy &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; np&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; tensorflow &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; tf&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; network_models.policy_net &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; Policy_net&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; network_models.discriminator &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; Discriminator&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; algo.ppo &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; PPOTrain&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;argparser&lt;/span&gt;():&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    parser = argparse.ArgumentParser()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    parser.add_argument(&lt;span class=&quot;string&quot;&gt;&amp;#x27;--logdir&amp;#x27;&lt;/span&gt;, &lt;span class=&quot;built_in&quot;&gt;help&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;#x27;log directory&amp;#x27;&lt;/span&gt;, default=&lt;span class=&quot;string&quot;&gt;&amp;#x27;log/train/gail&amp;#x27;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;    parser.add_argument(&lt;span class=&quot;string&quot;&gt;&amp;#x27;--savedir&amp;#x27;&lt;/span&gt;, &lt;span class=&quot;built_in&quot;&gt;help&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;#x27;save directory&amp;#x27;&lt;/span&gt;, default=&lt;span class=&quot;string&quot;&gt;&amp;#x27;trained_models/gail&amp;#x27;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    parser.add_argument(&lt;span class=&quot;string&quot;&gt;&amp;#x27;--gamma&amp;#x27;&lt;/span&gt;, &lt;span class=&quot;built_in&quot;&gt;type&lt;/span&gt;=&lt;span class=&quot;built_in&quot;&gt;float&lt;/span&gt;, default=&lt;span class=&quot;number&quot;&gt;0.95&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    parser.add_argument(&lt;span class=&quot;string&quot;&gt;&amp;#x27;--iteration&amp;#x27;&lt;/span&gt;, &lt;span class=&quot;built_in&quot;&gt;type&lt;/span&gt;=&lt;span class=&quot;built_in&quot;&gt;int&lt;/span&gt;, default=&lt;span class=&quot;built_in&quot;&gt;int&lt;/span&gt;(&lt;span class=&quot;number&quot;&gt;1e4&lt;/span&gt;))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; parser.parse_args()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;main&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;args&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    env = gym.make(&lt;span class=&quot;string&quot;&gt;&amp;#x27;CartPole-v0&amp;#x27;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    env.seed(&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    ob_space = env.observation_space&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    Policy = Policy_net(&lt;span class=&quot;string&quot;&gt;&amp;#x27;policy&amp;#x27;&lt;/span&gt;, env)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    Old_Policy = Policy_net(&lt;span class=&quot;string&quot;&gt;&amp;#x27;old_policy&amp;#x27;&lt;/span&gt;, env)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    PPO = PPOTrain(Policy, Old_Policy, gamma=args.gamma)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    D = Discriminator(env)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;comment&quot;&gt;# load the expert demonstrations: observations and actions&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    expert_observations = np.genfromtxt(&lt;span class=&quot;string&quot;&gt;&amp;#x27;trajectory/observations.csv&amp;#x27;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    expert_actions = np.genfromtxt(&lt;span class=&quot;string&quot;&gt;&amp;#x27;trajectory/actions.csv&amp;#x27;&lt;/span&gt;, dtype=np.int32)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    saver = tf.train.Saver()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;with&lt;/span&gt; tf.Session() &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; sess:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        writer = tf.summary.FileWriter(args.logdir, sess.graph)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        sess.run(tf.global_variables_initializer())&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        obs = env.reset()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        success_num = &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; iteration &lt;span 
class=&quot;keyword&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;range&lt;/span&gt;(args.iteration):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            observations = []&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            actions = []&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            rewards = []&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            v_preds = []&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            run_policy_steps = &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;while&lt;/span&gt; &lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                run_policy_steps += &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                obs = np.stack([obs]).astype(dtype=np.float32)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                act, v_pred = Policy.act(obs=obs, stochastic=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                act = act.item()  &lt;span class=&quot;comment&quot;&gt;# np.asscalar was removed in NumPy 1.23; .item() is the equivalent&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                v_pred = v_pred.item()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                next_obs, reward, done, info = env.step(act)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                observations.append(obs)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                actions.append(act)&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;                rewards.append(reward)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                v_preds.append(v_pred)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; done:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    next_obs = np.stack([next_obs]).astype(dtype=np.float32)  &lt;span class=&quot;comment&quot;&gt;# prepare to feed placeholder Policy.obs&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    _, v_pred = Policy.act(obs=next_obs, stochastic=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    v_preds_next = v_preds[&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;:] + [v_pred.item()]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    obs = env.reset()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    &lt;span class=&quot;keyword&quot;&gt;break&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;keyword&quot;&gt;else&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    obs = next_obs&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=&lt;span class=&quot;string&quot;&gt;&amp;#x27;episode_length&amp;#x27;&lt;/span&gt;, simple_value=run_policy_steps)])&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                               , iteration)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=&lt;span 
class=&quot;string&quot;&gt;&amp;#x27;episode_reward&amp;#x27;&lt;/span&gt;, simple_value=&lt;span class=&quot;built_in&quot;&gt;sum&lt;/span&gt;(rewards))])&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                               , iteration)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;sum&lt;/span&gt;(rewards) &amp;gt;= &lt;span class=&quot;number&quot;&gt;195&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                success_num += &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; success_num &amp;gt;= &lt;span class=&quot;number&quot;&gt;100&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    saver.save(sess, args.savedir + &lt;span class=&quot;string&quot;&gt;&amp;#x27;/model.ckpt&amp;#x27;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    &lt;span class=&quot;built_in&quot;&gt;print&lt;/span&gt;(&lt;span class=&quot;string&quot;&gt;&amp;#x27;Clear!! Model saved.&amp;#x27;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                    &lt;span class=&quot;keyword&quot;&gt;break&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;else&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                success_num = &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            observations = np.reshape(observations, [-&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;] + &lt;span class=&quot;built_in&quot;&gt;list&lt;/span&gt;(ob_space.shape))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            actions = np.array(actions).astype(np.int32)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; i &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;range&lt;/span&gt;(&lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;):  &lt;span class=&quot;comment&quot;&gt;# two discriminator updates: expert pairs vs. pairs from the current policy&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                D.train(expert_s = expert_observations,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                        expert_a = expert_actions,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                        agent_s = observations,&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                        agent_a = actions)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            d_rewards = D.get_rewards(agent_s=observations, agent_a=actions)  &lt;span class=&quot;comment&quot;&gt;# discriminator rewards replace the environment reward for PPO&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            d_rewards = np.reshape(d_rewards, [-&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;]).astype(dtype=np.float32)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            gaes = PPO.get_gaes(rewards=d_rewards, v_preds=v_preds, v_preds_next=v_preds_next)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            gaes = np.array(gaes).astype(dtype=np.float32)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;comment&quot;&gt;# gaes = (gaes - gaes.mean()) / gaes.std()&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            v_preds_next = np.array(v_preds_next).astype(dtype=np.float32)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;comment&quot;&gt;# train policy&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            inp = [observations, actions, gaes, d_rewards, v_preds_next]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            PPO.assign_policy_parameters()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; epoch &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;range&lt;/span&gt;(&lt;span class=&quot;number&quot;&gt;6&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                sample_indices = np.random.randint(low=&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;, high=observations.shape[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;],&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                                                   size=&lt;span class=&quot;number&quot;&gt;32&lt;/span&gt;)  &lt;span class=&quot;comment&quot;&gt;# indices are in [low, high)&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                sampled_inp = [np.take(a=a, indices=sample_indices, axis=&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;) &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; a &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; inp]  &lt;span class=&quot;comment&quot;&gt;# random minibatch of 32 transitions, drawn with replacement&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                PPO.train(obs=sampled_inp[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;],&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                          actions=sampled_inp[&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;],&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                          gaes=sampled_inp[&lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;],&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                          rewards=sampled_inp[&lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;],&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                          v_preds_next=sampled_inp[&lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;])&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            summary = PPO.get_summary(obs=inp[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;],&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                                      actions=inp[&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;],&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                                      gaes=inp[&lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;],&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                                      rewards=inp[&lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;],&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;                                      v_preds_next=inp[&lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;])&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            writer.add_summary(summary, iteration)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        writer.close()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; __name__ == &lt;span class=&quot;string&quot;&gt;&amp;#x27;__main__&amp;#x27;&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    args = argparser()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    main(args)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;

&lt;figure class=&quot;highlight plaintext&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;traj = Trajectory(observations, actions, infos=None, terminal=True)&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;

&lt;p&gt;&lt;strong&gt;--pedestrians&lt;/strong&gt;&lt;/p&gt;

    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/3660843287.html"/>
    <id>https://blog.aivgg.com/posts/3660843287.html</id>
    <published>2026-06-10T16:18:37.526Z</published>
    <updated>2026-06-10T16:24:04.382Z</updated>
    
    <content type="html"><![CDATA[<h3 id="概念"><a href="#概念" class="headerlink" title="概念"></a>Concepts</h3><p>Docker has three basic concepts:</p><p><code>Image</code>: a Docker image is a special file system. Besides the programs, libraries, resources, and configuration files a container needs at runtime, it also contains configuration parameters prepared for runtime (such as anonymous volumes, environment variables, and users). An image contains no dynamic data, and its content does not change after it is built.</p><p><code>Container</code>: the relationship between an image and a container is <strong>like that between a <code>class</code> and an <code>instance</code> in object-oriented programming: the image is the static definition, and the container is the running entity created from it</strong>. Containers can be created, started, stopped, deleted, paused, and so on.</p><p><code>Repository</code>: a repository is similar to a remote Git repository and <strong>stores images centrally</strong>, typically several tagged versions of the same image. Repositories are hosted on a registry such as Docker Hub.</p><p>The relationship among the three is shown below:</p><p><img src="https://img-blog.csdnimg.cn/72f22c291cba4217bbf0ba1f4432d33d.png" alt="Relationship between images, containers, and repositories"></p><p>Common Docker commands</p><h3 id="服务"><a href="#服务" class="headerlink" title="服务"></a>Service</h3><p>Show Docker version information</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">docker version</span><br></pre></td></tr></table></figure><p>Start Docker</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">systemctl start docker</span><br></pre></td></tr></table></figure><p>Stop Docker</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">systemctl stop docker</span><br></pre></td></tr></table></figure><p>Start Docker on boot</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">systemctl enable docker</span><br></pre></td></tr></table></figure><p>Restart the Docker service (the older <code>service</code> equivalent of <code>systemctl restart docker</code>)</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">service docker restart</span><br></pre></td></tr></table></figure><p>Stop the Docker service with <code>service</code></p><figure class="highlight 
javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">service docker stop</span><br></pre></td></tr></table></figure><h3 id="镜像"><a href="#镜像" class="headerlink" title="镜像"></a>Images</h3><p>Registries such as Docker Hub host a large number of high-quality images that you can pull and use directly.</p><p>Search for images</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">docker search [keyword]</span><br></pre></td></tr></table></figure><p>Pull an image</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">docker pull [options] [<span class="title class_">Docker</span> <span class="title class_">Registry</span> address[:port]/]repository[:tag]</span><br></pre></td></tr></table></figure><p>List images</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">docker image ls</span><br><span class="line">docker images</span><br></pre></td></tr></table></figure><p>Delete an image</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"># Delete the specified image</span><br><span class="line">docker rmi &lt;image <span class="title class_">ID</span>&gt;</span><br></pre></td></tr></table></figure><p>Export an image</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"># Save an image to a tar archive</span><br><span class="line">docker save -o image.tar [image:tag]</span><br></pre></td></tr></table></figure><p>Import an image</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">docker 
load -i image.tar</span><br></pre></td></tr></table></figure><p>Build an image</p><p>A Dockerfile is a plain-text configuration file that lets you quickly create custom images.</p><p>A Dockerfile consists of instructions written one per line, and lines starting with <code>#</code> are comments.</p><p><strong>Some common instructions</strong>:</p><ul><li>FROM: specify the base image</li><li>RUN: execute a command while building the image</li><li>COPY: copy files from the build context into the image</li><li>ADD: a more advanced copy that can also fetch remote URLs and extract local tar archives</li><li>CMD: the default command run when the container starts</li><li>ENV: set environment variables</li><li>EXPOSE: declare the port the container listens on (this does not publish the port by itself; use <code>docker run -p</code> for that)</li></ul><p>Here is an example Dockerfile that runs a Java application:</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">FROM eclipse-temurin:8-jre</span><br><span class="line">LABEL maintainer=&quot;jinshw &lt;jinshw@qq.com&gt;&quot;</span><br><span class="line">COPY mapcharts-0.0.1-SNAPSHOT.jar mapcharts.jar</span><br><span class="line">EXPOSE 8080</span><br><span class="line">CMD [&quot;java&quot;, &quot;-jar&quot;, &quot;mapcharts.jar&quot;]</span><br></pre></td></tr></table></figure><p>Command to build an image:</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">docker build [options] &lt;context path/<span class="variable constant_">URL</span>/-&gt;</span><br></pre></td></tr></table></figure><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span 
class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"># Specify the path to the <span class="title class_">Dockerfile</span></span><br><span class="line">docker build -f /path/to/a/<span class="title class_">Dockerfile</span> .</span><br><span class="line"># By default, use the <span class="title class_">Dockerfile</span> in the current directory</span><br><span class="line">docker build .</span><br><span class="line"># -t, --tag sets the name and tag of the built image</span><br><span class="line">docker build -t image-nginx:v1 .</span><br></pre></td></tr></table></figure><p>Run an image (that is, run a container):</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"># Create and start a new container</span><br><span class="line">docker run [image <span class="variable constant_">ID</span>]</span><br><span class="line"># Start a stopped container</span><br><span class="line">docker start [container <span class="variable constant_">ID</span>]</span><br></pre></td></tr></table></figure><h3 id="容器"><a href="#容器" class="headerlink" title="容器"></a>Containers</h3><p>There are two ways to start a container: create a new container from an image and start it, or restart a container that is in the stopped state. For example, the following creates and starts a container named <code>mynginx</code> from the image with ID <code>01da99b6476c</code>:</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">docker run --name mynginx -it 01da99b6476c</span><br></pre></td></tr></table></figure><p>The <code>-it</code> flags combine <code>-i</code> (keep STDIN open) and <code>-t</code> (allocate a pseudo-terminal), which puts you in the container's interactive mode.</p><p>List containers</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"># List the running containers on this host</span><br><span class="line">docker ps</span><br><span class="line"># List all containers on this host (running and stopped)</span><br><span class="line">docker ps -a</span><br></pre></td></tr></table></figure><p>Stop a container</p><figure 
class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"># Stop a running container (SIGTERM first, then SIGKILL after a grace period)</span><br><span class="line">docker stop [container <span class="variable constant_">ID</span>]</span><br><span class="line"># Kill the container's main process immediately</span><br><span class="line">docker kill [container <span class="variable constant_">ID</span>]</span><br></pre></td></tr></table></figure><p>Restart a container</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">docker restart [container <span class="variable constant_">ID</span>]</span><br></pre></td></tr></table></figure><p>Delete a container</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">docker rm [container <span class="variable constant_">ID</span>]</span><br></pre></td></tr></table></figure><p>Enter a container</p><figure class="highlight javascript"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"># Attach to the container's main process; exiting from this stdin stops the container</span><br><span class="line">docker attach [container <span class="variable constant_">ID</span>]</span><br><span class="line"># Open an interactive bash shell inside the container</span><br><span class="line">docker exec -it [container <span class="variable constant_">ID</span>] bash</span><br></pre></td></tr></table></figure><p>Common options for <code>docker exec</code>:</p><blockquote><p><code>-d, --detach</code>: run the command in the background inside the container</p><p><code>-i, --interactive</code>: keep STDIN open so the command can accept user input</p><p><code>-t, --tty</code>: allocate a pseudo-terminal</p></blockquote><h3 id="发布"><a href="#发布" class="headerlink" title="发布"></a>Publishing</h3><p>Once a container runs successfully, the image has proven to work. At this point you can share the image online so that others can use it.</p><p>Register an account on Docker Hub (hub.docker.com), then log in from the command line:</p><figure class="highlight bash"><table><tr><td 
class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">docker login</span><br></pre></td></tr></table></figure><p>Tag the local image with your username and a version:</p><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">docker image tag koa-demos:0.0.1 ruanyf/koa-demos:0.0.1</span><br></pre></td></tr></table></figure><p>Finally, push the image:</p><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">docker image push [username]/[repository]:[tag]</span><br></pre></td></tr></table></figure><h3 id="Demo"><a href="#Demo" class="headerlink" title="Demo"></a>Demo</h3><p>Creating a new image end to end, starting with its Dockerfile:</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line">FROM registry.cn-beijing.aliyuncs.com/selfdriveguard/ubuntu_20_sumo:1.14.1</span><br><span class="line"></span><br><span class="line">COPY ./requirements.txt /oasis_sumo/requirements.txt</span><br><span class="line"></span><br><span class="line">RUN sed -i &#x27;s/archive.ubuntu.com/mirrors.aliyun.com/g&#x27; /etc/apt/sources.list</span><br><span class="line">RUN apt-get update &amp;&amp; apt-get install -y python3.8 python3-pip &amp;&amp; apt-get install libgeos-dev --assume-yes</span><br><span class="line">RUN python3.8 -m pip install -i https://mirrors.aliyun.com/pypi/simple pip -U &amp;&amp; \</span><br><span class="line">    python3.8 -m pip config set global.index-url https://mirrors.aliyun.com/pypi/simple &amp;&amp; \</span><br><span class="line">    
python3.8 -m pip install --no-cache-dir --upgrade -I -r /oasis_sumo/requirements.txt &amp;&amp; \</span><br><span class="line">    python3.8 -m pip install eventlet==0.33.1</span><br></pre></td></tr></table></figure><p>Use the Dockerfile to build an image with <code>docker build</code>:</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"># Build it as a base image</span><br><span class="line">docker build -t oasis-sumo-base:v0.1 .</span><br></pre></td></tr></table></figure><p>Check the newly built image:</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">docker images</span><br></pre></td></tr></table></figure><p>The image now appears in the list, with image ID 860c279d2fec. <strong>Create a container from the new image:</strong></p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">docker run -it oasis-sumo-base:v0.1</span><br><span class="line"># or, by image ID</span><br><span class="line">docker run -it 860c279d2fec</span><br></pre></td></tr></table></figure><p>Check the running containers; the container ID is in the first column:</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">docker ps</span><br></pre></td></tr></table></figure><p>To inspect logs inside the container, start an interactive bash shell in it. Note that <code>docker exec</code> takes a container ID or name from <code>docker ps</code>, not the image ID:</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">docker exec -it [container ID] bash</span><br></pre></td></tr></table></figure><p><strong>In short, an image can be thought of as a custom runtime environment, and a container is an instantiated image: a running process.</strong></p>]]></content>
    
    <summary type="html">
    
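      &lt;h3 id=&quot;概念&quot;&gt;&lt;a href=&quot;#概念&quot; class=&quot;headerlink&quot; title=&quot;概念&quot;&gt;&lt;/a&gt;Concepts&lt;/h3&gt;&lt;p&gt;Docker has three basic concepts:&lt;/p&gt;
&lt;p&gt;&lt;code&gt;Image&lt;/code&gt;: a Docker image is a special file system. Besides the programs, libraries, resources, and configuration files a container needs at runtime, it also contains configuration parameters prepared for runtime (such as anonymous volumes, environment variables, and users). An image contains no dynamic data, and its content does not change after it is built.&lt;/p&gt;
&lt;p&gt;&lt;code&gt;Container&lt;/code&gt;: the relationship between an image and a container is &lt;strong&gt;like that between a &lt;code&gt;class&lt;/code&gt; and an &lt;code&gt;instance&lt;/code&gt; in object-oriented programming: the image is the static definition, and the container is the running entity created from it&lt;/strong&gt;. Containers can be created, started, stopped, deleted, paused, and so on.&lt;/p&gt;
&lt;p&gt;&lt;code&gt;Repository&lt;/code&gt;: a repository is similar to a remote Git repository and &lt;strong&gt;stores images centrally&lt;/strong&gt;, typically several tagged versions of the same image. Repositories are hosted on a registry such as Docker Hub.&lt;/p&gt;
&lt;p&gt;The relationship among the three is shown below:&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://img-blog.csdnimg.cn/72f22c291cba4217bbf0ba1f4432d33d.png&quot; alt=&quot;Relationship between images, containers, and repositories&quot;&gt;&lt;/p&gt;
&lt;p&gt;Common Docker commands&lt;/p&gt;
&lt;h3 id=&quot;服务&quot;&gt;&lt;a href=&quot;#服务&quot; class=&quot;headerlink&quot; title=&quot;服务&quot;&gt;&lt;/a&gt;Service&lt;/h3&gt;&lt;p&gt;Show Docker version information&lt;/p&gt;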
    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/3684848571.html"/>
    <id>https://blog.aivgg.com/posts/3684848571.html</id>
    <published>2026-06-10T16:18:37.525Z</published>
    <updated>2026-06-10T16:24:04.382Z</updated>
    
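    <content type="html"><![CDATA[<h3 id="背景"><a href="#背景" class="headerlink" title="背景"></a>Background</h3><p>Autonomous-driving algorithms are first debugged and evaluated in simulation, so a powerful, flexible simulator is essential to development and testing. When looking for a usable simulation tool, I focused mainly on the following features:</p><ul><li>Open source and free</li><li>Includes highway scenarios</li><li>Makes it easy to control and switch the scenario environment, with scenarios and environments as rich and realistic as possible</li><li>Makes it easy to control and switch the behavior patterns of moving objects in the scene (such as vehicles and pedestrians), with behavior patterns as rich and realistic as possible</li></ul><p>CARLA is an open-source project developed mainly by Intel Labs and the Computer Vision Center (CVC) in Barcelona. The original CARLA work covers three approaches to autonomous driving:</p><ul><li>The classic modular approach, with a vision-based perception module, a rule-based planner, and a behavior controller.</li><li>End-to-end <strong>imitation learning</strong>.</li><li>End-to-end <strong>reinforcement learning</strong>.</li></ul><p>Characteristics:</p><ul><li>Targets 3D urban driving; open source and free; supports perception, planning, and control.</li><li>Built on Unreal Engine 4 with a server-client architecture.</li><li>Runs on Linux and Windows.</li><li>Driven mainly through a Python API; newer versions also provide a C++ client API.</li></ul><h3 id="架构"><a href="#架构" class="headerlink" title="架构"></a>Architecture</h3><p>CARLA is an open-source simulator that can simulate realistic traffic, pedestrian behavior, vehicle sensor signals, and more. As shown in the figure below, the simulator is built with C++ and <strong>Unreal Engine</strong>, and users can manipulate and control the simulated environment from Python scripts through the Python API.</p><p><img src="https://img-blog.csdnimg.cn/b80d4acc1aec462a9f21b8a7c76edae6.png" alt="CARLA architecture"></p><p>CARLA is divided into <strong>two modules, a server and a client</strong>. The server builds the <strong>simulated world</strong>, while the client is <strong>controlled by the user</strong> and is used to adjust and change that world.</p><ul><li>Server: the server is responsible for everything <strong>related to the simulation itself</strong>, from 3D rendering of cars, streets, and buildings, to building sensor models, to physics computation. Like a <strong>creator</strong>, it builds the entire world and updates it according to commands from the client. Its 3D rendering is based on Unreal Engine.</li><li>Client: if the server builds the whole world, then <strong>how the world runs</strong> at each moment (for example, what the weather is, how many cars are driving, and how fast) is <strong>controlled</strong> by the client. Users write Python scripts (newer versions also support C++) that send commands to the server directing how the world changes, and the server executes them. The client <strong>can also receive information from the server</strong>, such as road images captured by a camera.</li></ul><h3 id="核心模块"><a href="#核心模块" class="headerlink" title="核心模块"></a>Core modules</h3><ol><li><strong>Traffic Manager</strong>: a core reason autonomous driving is hard is that the real world has so many cars. CARLA therefore provides a dedicated <strong>Traffic Manager module to simulate complex, real-world-like traffic</strong>. With it, users can define many vehicles of different models, behavior patterns, and speeds that share the road with your autonomous vehicle (the ego vehicle).</li><li><strong>Sensors:</strong> CARLA provides a wide range of <strong>sensor models that mimic real-world sensors</strong>, including cameras, LiDAR, radar, IMU, and GNSS. To bring the simulation closer to the real world, its cameras even produce images with lens distortion and motion blur. Users typically attach these sensors to different vehicles to collect various kinds of data.</li>
<li><strong>Recorder:</strong> as the saying goes, a simulation that can't be reproduced isn't a good simulation. This module <strong>records the state of the simulation at every step</strong>, so that runs can be <strong>reviewed and replayed</strong>.</li><li><strong>ROS bridge:</strong> this module lets CARLA interact with ROS and Autoware. It is what <strong>makes it possible to test your own autonomous-driving system in simulation</strong>, so it is very important and will be covered in detail later.</li><li><strong>Open Assets</strong>: this module lets you <strong>add a library of custom objects</strong> to the simulated world. For example, you could add a cool-looking flying car that doesn't exist in the real world to the default vehicle blueprints and then use it from the client.</li></ol><h3 id="环境"><a href="#环境" class="headerlink" title="环境"></a>Environment</h3><p>Creating an urban environment takes three steps:</p><ol><li>Lay out roads and sidewalks</li><li>Place buildings, vegetation, terrain, and traffic infrastructure</li><li>Specify where dynamic objects can appear</li></ol><h3 id="传感器"><a href="#传感器" class="headerlink" title="传感器"></a>Sensors</h3><p>Cameras provide RGB images, depth, and semantic segmentation; the semantic segmentation in the original release covers 12 object classes (newer versions define more).</p><p>CARLA also provides GPS coordinates, heading, speed, acceleration, and collision data, as well as traffic-rule evaluation data such as the fraction of the driven trajectory spent in the wrong lane, plus the exact positions and bounding boxes of all dynamic objects.</p><h3 id="经典模块"><a href="#经典模块" class="headerlink" title="经典模块"></a>Classic modular pipeline</h3><p>The local planner relies only on the environment detected by perception. Internally it is a state machine with the states lane following, left turn, right turn, going straight through an intersection, and stopping.</p><p>Control uses a PID controller.</p><p>Semantic segmentation for perception is based on RefineNet.</p><p>Whether the vehicle is at an intersection is decided by an AlexNet-based binary classifier.</p><h3 id="特性"><a href="#特性" class="headerlink" title="特性"></a>Features</h3><ul><li>Scalability through a server multi-client architecture: multiple clients on the same node or on different nodes can control different actors.</li><li>Flexible API: CARLA exposes a powerful API that lets users control every aspect of the simulation, including traffic generation, pedestrian behavior, weather, and sensors.</li><li>Autonomous-driving sensor suites: users can configure diverse sensor suites, including LiDAR, multiple cameras, depth sensors, and GPS.</li><li>Fast simulation for planning and control: this mode disables rendering to offer fast execution of traffic simulation and road behaviors that don't need graphics.</li><li>Map generation: users can easily create their own OpenDRIVE-compliant maps with tools such as RoadRunner.</li><li>Traffic scenario simulation: the ScenarioRunner engine lets users define and execute different traffic scenarios based on modular behaviors.</li><li>ROS integration: CARLA integrates with ROS through its ROS bridge.</li><li>Autonomous-driving baselines: CARLA provides autonomous-driving baselines as runnable agents, including an Autoware agent and a conditional imitation learning agent.</li></ul><h3 id="安装"><a href="#安装" class="headerlink" title="安装"></a>Installation</h3><p>Installation steps on Ubuntu:</p><p>Install the basic packages:</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">sudo apt-get update 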
&amp;&amp;</span><br><span class="line">sudo apt-get install wget software-properties-common &amp;&amp;</span><br><span class="line">sudo add-apt-repository ppa:ubuntu-toolchain-r/test &amp;&amp;</span><br><span class="line">wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - &amp;&amp;</span><br><span class="line">sudo apt-add-repository &quot;deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-8 main&quot; &amp;&amp;</span><br><span class="line">sudo apt-get update</span><br></pre></td></tr></table></figure><p>Ubuntu 20.04:</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">sudo apt-add-repository &quot;deb http://apt.llvm.org/focal/ llvm-toolchain-focal main&quot;</span><br><span class="line">sudo apt-get install build-essential clang-10 lld-10 g++-7 cmake ninja-build libvulkan1 python python-dev python3-dev python3-pip libpng-dev libtiff5-dev libjpeg-dev tzdata sed curl unzip autoconf libtool rsync libxml2-dev git</span><br><span class="line">sudo update-alternatives --install /usr/bin/clang++ clang++ /usr/lib/llvm-10/bin/clang++ 180 &amp;&amp;</span><br><span class="line">sudo update-alternatives --install /usr/bin/clang clang /usr/lib/llvm-10/bin/clang 180</span><br></pre></td></tr></table></figure><p>安装Unreal Engine：</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">git clone --depth 1 -b carla https://github.com/CarlaUnreal/UnrealEngine.git ~/UnrealEngine_4.26</span><br><span class="line">cd ~/UnrealEngine_4.26</span><br><span class="line">./Setup.sh &amp;&amp; ./GenerateProjectFiles.sh &amp;&amp; make</span><br><span 
class="line">cd ~/UnrealEngine_4.26/Engine/Binaries/Linux &amp;&amp; ./UE4Editor</span><br><span class="line"></span><br></pre></td></tr></table></figure><p>Build CARLA:</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">git clone https://github.com/carla-simulator/carla</span><br><span class="line">cd carla</span><br><span class="line">./Update.sh</span><br></pre></td></tr></table></figure><p>Set the UE4 environment variable by adding the following line to <code>~/.bashrc</code> (for example with <code>gedit ~/.bashrc</code>), then reload it or open a new terminal:</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">export UE4_ROOT=~/UnrealEngine_4.26</span><br></pre></td></tr></table></figure><p>Build the CARLA client (the Python API); the target Python versions can be given as a comma-separated list:</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">make PythonAPI ARGS=&quot;--python-version=3.7,3.8&quot;</span><br></pre></td></tr></table></figure><p>This produces two kinds of files, an .egg and a .whl. The .egg can be used without installation, while the .whl is tied to the platform and Python version and is installed with pip:</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">pip3 install &lt;path/to/wheel&gt;.whl</span><br></pre></td></tr></table></figure><p>Build the CARLA server and launch it in the Unreal Editor:</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">make launch</span><br></pre></td></tr></table></figure><p>Click Play in the editor to start the simulation, then run the example scripts:</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"># Terminal A</span><br><span class="line">cd PythonAPI/examples</span><br><span class="line">python3 -m pip install -r 
requirements.txt</span><br><span class="line">python3 generate_traffic.py</span><br><span class="line"></span><br><span class="line"># Terminal B</span><br><span class="line">cd PythonAPI/examples</span><br><span class="line">python3 dynamic_weather.py</span><br></pre></td></tr></table></figure><h3 id="使用"><a href="#使用" class="headerlink" title="使用"></a>Usage</h3><p>You can write a client in Python to interact with the CARLA simulation. In the CARLA Python library, a client is a <code>carla.Client</code> object:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> carla</span><br><span class="line"></span><br><span class="line">client = carla.Client(<span class="string">&quot;localhost&quot;</span>, <span class="number">2000</span>)  <span class="comment"># server host and RPC port (2000 by default)</span></span><br><span class="line">client.set_timeout(<span class="number">10.0</span>)  <span class="comment"># network calls time out after 10 s</span></span><br></pre></td></tr></table></figure><p>The client object is the sole entry point for interacting with the CARLA environment. Once you have a client, you can obtain the world (<code>World</code>) of the CARLA environment. The World object represents the simulated world itself: anything you want to create in the world is added through it.</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">world = client.load_world(<span class="string">&#x27;Town02&#x27;</span>)</span><br></pre></td></tr></table></figure><p>Here the client loads Town02, one of the built-in maps shipped with CARLA, and gets back the corresponding World object. CARLA ships several other maps as well; pass a different name to <code>load_world</code> to load a different map.</p><p>With the world object in hand, let's get a feel for the kinds of interaction available. For example, we can control the world's weather and time of day (the position of the sun):</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">weather = carla.WeatherParameters(</span><br><span class="line">    cloudiness=<span class="number">0.0</span>,</span><br><span class="line">    precipitation=<span 
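class="number">0.0</span>,</span><br><span class="line">    sun_altitude_angle=<span class="number">50.0</span>)</span><br><span class="line">world.set_weather(weather)</span><br></pre></td></tr></table></figure><p>This uses <code>carla.WeatherParameters</code> to create a weather preset with a cloudless sky, no precipitation, and a sun altitude angle of 50 degrees, in other words a clear, sunny day, and then applies it with the world's <code>set_weather</code> method.</p><p>Actors:</p><blockquote><p>Any object in the CARLA world that can be created through the client is called an actor, including vehicles, pedestrians, sensors, and so on.</p></blockquote><p>Blueprints:</p><blockquote><p>Creating an actor requires a description of what to create, a kind of template. In CARLA this template is called a blueprint.</p></blockquote><p>CARLA ships with a built-in blueprint library containing many different actors: for example, we can create a Tesla Model 3, a BMW, a Citroën, and so on.</p><p>Create a blueprint (template) for a white Tesla:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">model3_bp = world.get_blueprint_library().find(<span class="string">&#x27;vehicle.tesla.model3&#x27;</span>)</span><br><span class="line">model3_bp.set_attribute(<span class="string">&#x27;color&#x27;</span>, <span class="string">&#x27;255,255,255&#x27;</span>)  <span class="comment"># RGB given as a string</span></span><br></pre></td></tr></table></figure><p><code>world.get_blueprint_library()</code> returns CARLA's built-in blueprint library, and its <code>find</code> method looks up the Model 3 by ID. You can also call <code>filter</code> with a wildcard pattern to get several blueprints at once; for example, <code>world.get_blueprint_library().filter('vehicle.bmw.*')</code> returns the blueprints of all BMW models. Every blueprint can change its attributes with <code>set_attribute</code>; here we set the color to white.</p><p>Actor lifecycle:</p><p>An actor can be spawned (Spawning), handled (Handling), and destroyed (Destruction).</p><ul><li>Spawning: a blueprint is only a template; using it to create one or more actors in the world is called spawning. Because an actor is an object that exists in the world, you must tell the environment where it should appear, that is, its spawn point.</li><li>Handling: once an actor is spawned, the client can control its behavior, for example making a car drive and controlling its throttle and steering.</li><li>Destruction: when an actor is no longer needed, you can destroy it to free simulation resources and keep the environment running smoothly.</li></ul><p>Use the world object's <code>spawn_actor</code> or <code>try_spawn_actor</code> method to spawn an actor: <code>spawn_actor</code> raises an exception if spawning fails, while <code>try_spawn_actor</code> returns <code>None</code>. As mentioned above, you must tell the world where the actor's spawn point is; in CARLA a spawn point is represented by <code>carla.Transform</code>, a location plus a rotation:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">transform = carla.Transform(carla.Location(x=<span class="number">230</span>, y=<span class="number">195</span>, z=<span class="number">40</span>), carla.Rotation(yaw=<span class="number">180</span>))</span><br><span class="line">actor = world.spawn_actor(model3_bp, transform)</span><br></pre></td></tr></table></figure>
<p>Objects in the world <strong>have volume</strong>, so the location you specify may already be occupied, for example by another car or by a building. In that case the actor would <strong>collide</strong> when spawned, and spawning fails. To avoid this, CARLA provides an interface that returns the map's recommended spawn points, which lie on the road and clear of static geometry. They can still be taken by another actor, so <code>try_spawn_actor</code> is the safer choice when spawning many vehicles.</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> random</span><br><span class="line">spawn_points = world.get_map().get_spawn_points()  <span class="comment"># list of carla.Transform</span></span><br><span class="line">model3_spawn_point = random.choice(spawn_points)</span><br><span class="line">model3 = world.spawn_actor(model3_bp, model3_spawn_point)</span><br></pre></td></tr></table></figure><p>After spawning, the spawn method returns the spawned actor object, which we can use to control the actor's behavior. For example, the code below puts the Model 3 we just spawned on autopilot (driven by the Traffic Manager) so that it drives around the world following traffic rules.</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">model3.set_autopilot(<span class="literal">True</span>)</span><br></pre></td></tr></table></figure><p>You can also move the car to a new position directly:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">location = model3.get_location()</span><br><span class="line">location.x += <span class="number">10.0</span>  <span class="comment"># shift 10 m along the world x axis</span></span><br><span class="line">model3.set_location(location)</span><br></pre></td></tr></table></figure><p>Different actor types expose different controllable parameters; for example, vehicles take a <code>carla.VehicleControl</code> (throttle, steer, brake) through <code>apply_control</code>, while pedestrians use <code>carla.WalkerControl</code>.</p><p>When an actor is no longer needed, you can destroy it to free computing resources:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">destroyed_successfully = actor.destroy()  <span class="comment"># True if the actor was destroyed</span></span><br></pre></td></tr></table></figure><p>Sensors:</p><p>A sensor is a special kind of actor whose blueprint can also be found in the blueprint library. CARLA already supports many sensors, for example:</p><ul><li>Cameras: Depth, RGB, Semantic segmentation</li><li>Detectors: Collision, Lane invasion, 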
Obstacle</li><li>其他： GNSS ， IMU ， LIDAR raycast ， Radar</li></ul><p>传感器跟其他的Actor最大的不同是，它们需要被安装在车上，因此在生成传感器的时候，需要将其附着到一个车辆类型的Actor上，而出生点是针对于这台车本身的坐标系给定的</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">camera_bp = world.get_blueprint_library().find(<span class="string">&#x27;sensor.camera.rgb&#x27;</span>) </span><br><span class="line">camera = world.spawn_actor(camera_bp,                         carla.Transform(carla.Location(x=-<span class="number">5.5</span>, z=<span class="number">2.5</span>), carla.Rotation(pitch=<span class="number">8.0</span>)), model3, carla.AttachmentType.SpringArm ) </span><br><span class="line">camera.listen(<span class="keyword">lambda</span> image:image.save_to_disk(<span class="string">&#x27;output/%06d.png&#x27;</span> % image.frame) </span><br></pre></td></tr></table></figure><p>首先从蓝图库中找到RGB摄像头模板，然后利用这个蓝图生成摄像头Actor，并将其附着到前边生成好的Model 3上，我们选择了摄像头附着类型为：carla.AttachmentType.SpringArm，并将其位置设置到后方，这样我们就可以像从一个第三者的角度排到行驶的车辆了。</p><p>每一个传感器都有一个listen方法，该方法接收一个callback作为参数，我们可以自定义callback里边的逻辑，callback将会在传感器拿到数据后被调用，并能够获取到这些数据。</p><p>观察者 spectator：</p><p>使用Python脚本生成了一台Model3 ，并让它行驶在路上，我们的Python脚本并没有能力输出视频，我们怎么来确认我们的 Model 3已经被创建，并在路上行驶呢</p><blockquote><p>Carla给我们提供了一个所谓的观察者 Spectator，你可以将其理解为Carla环境的视角，我们可以通过修改观察者的参数，切换Carla环境的视角。</p></blockquote><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">while</span> running:</span><br><span class="line">    spectator = world.get_spectator()</span><br><span class="line">    transform = model3.get_transform()</span><br><span class="line">    
spectator.set_transform(carla.Transform(transform.location + carla.Location(z=<span class="number">50</span>),</span><br><span class="line">    carla.Rotation(pitch=-<span class="number">90</span>)))</span><br><span class="line">    time.sleep(<span class="number">5</span>)</span><br></pre></td></tr></table></figure><p>上边这段代码，将观察者的位置到我们的Model 3的正上方（z轴），并将视角调为-90度，即向下看，每隔5s中跟着Model 3的位置重置一下观察者的位置，这样就可以在Model 3的正上方追踪并观察这台车了。</p><p>这个项目最后会得到两个输出</p><ul><li>我们设置了观察值来实时追踪Model 3的行驶，因此在Carla环境中可以看到车子在行驶</li><li>我们在摄像头的callback中添加了保存摄像头采集到的照片到output文件夹的动作，因此我们会在output文件夹中看到很多摄像头采集回来的照片</li></ul><p><img src="https://img-blog.csdnimg.cn/e5626a7a7276485b898db27c74facba9.png" alt="在这里插入图片描述"></p>]]></content>
    
    <summary type="html">
    
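      &lt;h3 id=&quot;背景&quot;&gt;&lt;a href=&quot;#背景&quot; class=&quot;headerlink&quot; title=&quot;背景&quot;&gt;&lt;/a&gt;Background&lt;/h3&gt;&lt;p&gt;Autonomous-driving algorithms are first debugged and evaluated in simulation, so a powerful, flexible simulator is essential during development and testing. When looking for a suitable simulator, I focused mainly on the following features:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Open source and free&lt;/li&gt;
&lt;li&gt;Includes highway scenarios&lt;/li&gt;
&lt;li&gt;The scene environment is easy to control and switch, and as rich and realistic as possible&lt;/li&gt;
&lt;li&gt;The behavior of moving objects in the scene (vehicles, pedestrians, etc.) is easy to control and switch, and as rich and realistic as possible&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;CARLA is an open-source project developed mainly by Intel Labs and the Computer Vision Center (CVC) in Barcelona. It has been used to study three approaches to autonomous driving:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;The classic modular pipeline: a vision-based perception module, a rule-based planner, and a behavior controller.&lt;/li&gt;
&lt;li&gt;End-to-end &lt;strong&gt;imitation learning&lt;/strong&gt;.&lt;/li&gt;
&lt;li&gt;End-to-end &lt;strong&gt;reinforcement learning&lt;/strong&gt;.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Features:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Targets 3D urban driving scenes; open source and free; supports perception, planning, and control.&lt;/li&gt;
&lt;li&gt;Based on Unreal Engine 4, with a server-client architecture.&lt;/li&gt;
&lt;li&gt;Runs on Linux and Windows.&lt;/li&gt;
&lt;li&gt;Scripted mainly through the Python API. A C++ client library also exists, but it is used far less often.&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id=&quot;架构&quot;&gt;&lt;a href=&quot;#架构&quot; class=&quot;headerlink&quot; title=&quot;架构&quot;&gt;&lt;/a&gt;Architecture&lt;/h3&gt;&lt;p&gt;CARLA is an open-source simulator that can simulate realistic traffic environments, pedestrian behavior, vehicle sensor signals, and more. As the figure below shows, the simulator is built with C++ and the &lt;strong&gt;Unreal Engine&lt;/strong&gt;. Users operate and control the simulated environment from Python scripts through the Python API.&lt;/p&gt;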
&lt;p&gt;&lt;img src=&quot;https://img-blog.csdnimg.cn/b80d4acc1aec462a9f21b8a7c76edae6.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/2297823926.html"/>
    <id>https://blog.aivgg.com/posts/2297823926.html</id>
    <published>2026-06-10T16:18:36.437Z</published>
    <updated>2026-06-10T16:24:04.381Z</updated>
    
    <content type="html"><![CDATA[<p>Project</p><p>An Unreal Engine Project holds all the content and code that make up a game. On disk, a project consists of many directories, for example for Blueprints and Materials. You can rename the project's directories and change their hierarchy at any time.</p><p>The Content Browser in the Unreal Editor shows the same directory structure as the project folder on your hard drive.</p><p>Every project has a corresponding .uproject file, which you need in order to create, open, or save the project. You can create any number of projects and work on them in parallel.</p><h3 id="对象"><a href="#对象" class="headerlink" title="对象"></a>Object</h3><p>In Unreal Engine, the most basic class is called Object. In other words, it is the basic building block and contains the fundamental functionality for assets. Most classes in Unreal Engine inherit from Object (or take some of their functionality from it).</p><p>In C++, UObject is the base class of all Objects. It provides features such as garbage collection, exposing variables to the editor through metadata (UProperty), and serialization for saving and loading.</p><h3 id="类"><a href="#类" class="headerlink" title="类"></a>Class</h3><p>A Class defines the behavior and properties of an Actor or Object in Unreal Engine. Classes are hierarchical: a class inherits information from its parent class (the class it was derived from) and passes that information on to its child classes. Classes can be created in C++ code or in Blueprints.</p><h3 id="蓝图"><a href="#蓝图" class="headerlink" title="蓝图"></a>Blueprints</h3><p>The Blueprint Visual Scripting system (Blueprints for short) is a complete gameplay scripting system. It lets you create gameplay elements in the Unreal Editor through a node-based interface. As with many common scripting languages, you can use it to define object-oriented classes or objects in the engine. When working with UE4, you will often find that classes defined with Blueprints are themselves simply called Blueprints.</p><h3 id="Actor"><a href="#Actor" class="headerlink" title="Actor"></a>Actor</h3><p>Any object that can be placed into a level is an Actor, such as a Camera, a Static Mesh, or a Player Start location. Actors support 3D transformations such as translation, rotation, and scaling. Actors can be created (spawned) and destroyed from gameplay code (C++ or Blueprints).</p><p>My own understanding: Actor is the base class of the various abstract components in a level. These are not limited to physical objects; a "Player Start" location is one, for example.</p><h3 id="Pawn"><a href="#Pawn" class="headerlink" title="Pawn"></a>Pawn</h3><p>Pawn is a subclass of Actor that serves as an in-game avatar or persona, for example a character in the game. A Pawn can be controlled by a player, or by the game's AI as a non-player character (NPC).</p><p>When a Pawn is controlled by a human or AI player, it is considered Possessed. Conversely, when it is not controlled by a human or AI player, it is considered Unpossessed.</p><h3 id="角色"><a href="#角色" class="headerlink" title="角色"></a>Character</h3><p>A Character is a subclass of the Pawn Actor intended to be used as a player character. The Character subclass includes collision setup, input bindings for bipedal movement, and additional code for controlling movement.</p><h3 id="组件"><a href="#组件" class="headerlink" title="组件"></a>Component</h3><p>A Component is a piece of functionality that can be added to an Actor.</p><p>When you add a Component to an Actor, the Actor gains the functionality that the Component provides. For example:</p><p>a Spot Light Component makes your Actor emit light like a spot light,<br>a Rotating Movement Component makes your Actor spin around,<br>an Audio Component lets your Actor play sounds.</p><p>Components must be attached to an Actor; they cannot exist on their own.</p><h3 id="世界"><a href="#世界" class="headerlink" title="世界"></a>World</h3><p>A World is a container that holds all the levels in the game. It handles level streaming and spawns (creates) dynamic Actors.</p><h3 id="玩家控制器"><a href="#玩家控制器" class="headerlink" title="玩家控制器"></a>Player Controller</h3><p>A Player Controller takes the player's input and turns it into interactions in the game. Every game has at least one Player Controller. A Player Controller usually possesses a Pawn or Character as the player's representation in the game.</p><p>The associated C++ class is PlayerController.</p><h3 id="AI控制器"><a href="#AI控制器" class="headerlink" title="AI控制器"></a>AI Controller</h3><p>Just as a Player Controller possesses a Pawn to represent the player in the game, an AI Controller possesses a Pawn to represent a non-player character (NPC). By default, Pawns and Characters end up being possessed by a basic AI Controller. The exceptions are Pawns set to be possessed by a Player Controller and Pawns told not to create an AI Controller for themselves.</p><p>The associated C++ class is AIController.</p><p><a href="https://zhuanlan.zhihu.com/p/27448628">https://zhuanlan.zhihu.com/p/27448628</a></p><p><a href="https://zhuanlan.zhihu.com/p/535829374">https://zhuanlan.zhihu.com/p/535829374</a></p><p><a href="https://blog.csdn.net/brzzuibang/article/details/105823494">https://blog.csdn.net/brzzuibang/article/details/105823494</a></p><p><a href="https://docs.unrealengine.com/4.27/zh-CN/InteractiveExperiences/Vehicles/VehicleUserGuide/">https://docs.unrealengine.com/4.27/zh-CN/InteractiveExperiences/Vehicles/VehicleUserGuide/</a></p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line">transform = carla.Transform(transform.location + carla.Location(0, 0, SPAWN_OFFSET_Z),</span><br><span class="line">                            transform.rotation)</span><br><span class="line"></span><br><span class="line"># SpawnActor is a command, not an Actor, so it has no apply_control().</span><br><span class="line"># Follow-up commands are chained with then(); FutureActor stands for the</span><br><span class="line"># actor that this SpawnActor command creates.</span><br><span class="line">spawn = carla.command.SpawnActor(blueprint, transform)</span><br><span class="line">spawn = spawn.then(carla.command.SetSimulatePhysics(carla.command.FutureActor, False))</span><br><span class="line">if number_of_wheels == 4:</span><br><span class="line">    # Car. VehicleControl.throttle is expected in [0.0, 1.0]; with physics</span><br><span class="line">    # disabled above, the throttle is unlikely to move the car.</span><br><span class="line">    control = carla.VehicleControl(throttle=1.0)</span><br><span class="line">    spawn = spawn.then(carla.command.ApplyVehicleControl(carla.command.FutureActor, control))</span><br><span class="line"></span><br><span class="line">response = self.client.apply_batch_sync([spawn], False)[0]</span><br><span class="line">if response.error:</span><br><span class="line">    logging.error(&#x27;Spawn carla actor failed. %s&#x27;, response.error)</span><br><span class="line">    return INVALID_ACTOR_ID</span><br></pre></td></tr></table></figure><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line">if number_of_wheels == 4:</span><br><span class="line">    # Car (4 wheels); VehicleControl.throttle is expected in [0.0, 1.0]</span><br><span class="line">    control = carla.VehicleControl()</span><br><span class="line">    control.throttle = 1.0</span><br><span class="line">    vehicle.apply_control(control)</span><br><span class="line">elif number_of_wheels == 2:</span><br><span class="line">    # Two-wheeler</span><br><span class="line">    control = carla.VehicleControl()</span><br><span class="line">    control.throttle = 1.0</span><br><span class="line">    vehicle.apply_control(control)</span><br></pre></td></tr></table></figure><figure class="highlight c++"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">// Notes: GetAnimInstance, VehicleAnimInstance</span></span><br><span class="line"><span class="comment">// check(VehicleAnim-&gt;GetWheeledVehicleMovementComponent() != nullptr)</span></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="type">void</span> <span class="title">ACarlaWheeledVehicle::SetSimulatePhysics</span><span class="params">(<span class="type">bool</span> enabled)</span> </span>&#123;</span><br><span class="line">  <span class="keyword">if</span>(!<span class="built_in">GetCarlaMovementComponent</span>&lt;UDefaultMovementComponent&gt;())</span><br><span class="line">  &#123;</span><br><span class="line">    <span class="keyword">return</span>;</span><br><span class="line">  &#125;</span><br><span class="line"></span><br><span class="line">  UWheeledVehicleMovementComponent4W *Vehicle4W = <span class="built_in">Cast</span>&lt;UWheeledVehicleMovementComponent4W&gt;(</span><br><span class="line">      <span class="built_in">GetVehicleMovement</span>());</span><br><span class="line">  <span class="built_in">check</span>(Vehicle4W != <span class="literal">nullptr</span>);</span><br><span class="line"></span><br><span class="line">  <span class="keyword">if</span>(bPhysicsEnabled == enabled)</span><br><span class="line">    <span class="keyword">return</span>;</span><br><span class="line"></span><br><span class="line">  <span class="built_in">SetActorEnableCollision</span>(<span class="literal">true</span>);</span><br><span class="line">  <span class="keyword">auto</span> RootComponent = <span class="built_in">Cast</span>&lt;UPrimitiveComponent&gt;(<span class="built_in">GetRootComponent</span>());</span><br><span class="line">  RootComponent-&gt;<span class="built_in">SetSimulatePhysics</span>(enabled);</span><br><span class="line">  RootComponent-&gt;<span class="built_in">SetCollisionEnabled</span>(ECollisionEnabled::QueryAndPhysics);</span><br><span class="line"></span><br><span class="line">  UVehicleAnimInstance *VehicleAnim = <span class="built_in">Cast</span>&lt;UVehicleAnimInstance&gt;(<span class="built_in">GetMesh</span>()-&gt;<span class="built_in">GetAnimInstance</span>());</span><br><span class="line">  <span class="built_in">check</span>(VehicleAnim != <span class="literal">nullptr</span>);</span><br><span class="line"></span><br><span class="line">  <span class="built_in">GetWorld</span>()-&gt;<span class="built_in">GetPhysicsScene</span>()-&gt;<span class="built_in">GetPxScene</span>()-&gt;<span class="built_in">lockWrite</span>();</span><br><span class="line">  <span class="keyword">if</span> (enabled)</span><br><span class="line">  &#123;</span><br><span class="line">    Vehicle4W-&gt;<span class="built_in">RecreatePhysicsState</span>();</span><br><span class="line">    VehicleAnim-&gt;<span class="built_in">ResetWheelCustomRotations</span>();</span><br><span class="line">  &#125;</span><br><span class="line">  <span class="keyword">else</span></span><br><span class="line">  &#123;</span><br><span class="line">    Vehicle4W-&gt;<span class="built_in">DestroyPhysicsState</span>();</span><br><span class="line">  &#125;</span><br><span class="line"></span><br><span class="line">  <span class="built_in">GetWorld</span>()-&gt;<span class="built_in">GetPhysicsScene</span>()-&gt;<span class="built_in">GetPxScene</span>()-&gt;<span class="built_in">unlockWrite</span>();</span><br><span class="line"></span><br><span class="line">  bPhysicsEnabled = enabled;</span><br><span class="line"></span><br><span class="line">  <span class="built_in">ResetConstraints</span>();</span><br><span class="line">&#125;</span><br></pre></td></tr></table></figure><p>Adding a new vehicle to CARLA:</p><p><a href="https://carla.readthedocs.io/en/latest/tuto_A_add_vehicle/">https://carla.readthedocs.io/en/latest/tuto_A_add_vehicle/</a></p><p><a href="https://docs.unrealengine.com/4.27/zh-CN/InteractiveExperiences/Vehicles/VehicleUserGuide/">https://docs.unrealengine.com/4.27/zh-CN/InteractiveExperiences/Vehicles/VehicleUserGuide/</a></p><p><img src="/home/tony/.config/Typora/typora-user-images/image-20230306101624753.png" alt="image-20230306101624753"></p><p><a href="https://blog.csdn.net/qq_44905590/article/details/103034017">https://blog.csdn.net/qq_44905590/article/details/103034017</a></p>]]></content>
    
    <summary type="html">
    
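      &lt;p&gt;Project&lt;/p&gt;
&lt;p&gt;An Unreal Engine Project holds all the content and code that make up a game. On disk, a project consists of many directories, for example for Blueprints and Materials. You can rename the project's directories and change their hierarchy at any time.&lt;/p&gt;
&lt;p&gt;The Content Browser in the Unreal Editor shows the same directory structure as the project folder on your hard drive.&lt;/p&gt;
&lt;p&gt;Every project has a corresponding .uproject file, which you need in order to create, open, or save the project. You can create any number of projects and work on them in parallel.&lt;/p&gt;
&lt;h3 id=&quot;对象&quot;&gt;&lt;a href=&quot;#对象&quot; class=&quot;headerlink&quot; title=&quot;对象&quot;&gt;&lt;/a&gt;Object&lt;/h3&gt;&lt;p&gt;In Unreal Engine, the most basic class is called Object. In other words, it is the basic building block and contains the fundamental functionality for assets. Most classes in Unreal Engine inherit from Object (or take some of their functionality from it).&lt;/p&gt;
&lt;p&gt;In C++, UObject is the base class of all Objects. It provides features such as garbage collection, exposing variables to the editor through metadata (UProperty), and serialization for saving and loading.&lt;/p&gt;
&lt;h3 id=&quot;类&quot;&gt;&lt;a href=&quot;#类&quot; class=&quot;headerlink&quot; title=&quot;类&quot;&gt;&lt;/a&gt;Class&lt;/h3&gt;&lt;p&gt;A Class defines the behavior and properties of an Actor or Object in Unreal Engine. Classes are hierarchical: a class inherits information from its parent class (the class it was derived from) and passes that information on to its child classes. Classes can be created in C++ code or in Blueprints.&lt;/p&gt;
&lt;h3 id=&quot;蓝图&quot;&gt;&lt;a href=&quot;#蓝图&quot; class=&quot;headerlink&quot; title=&quot;蓝图&quot;&gt;&lt;/a&gt;Blueprints&lt;/h3&gt;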
    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/648375178.html"/>
    <id>https://blog.aivgg.com/posts/648375178.html</id>
    <published>2026-06-10T16:18:36.419Z</published>
    <updated>2026-06-10T16:24:04.380Z</updated>
    
    <content type="html"><![CDATA[<p>SUMO ( Simulation of Urban Mobility) 是免费、开源的<strong>交通系统仿真软件</strong>，可以实现交通流的微观控制，即具体到道路上每一辆车的运行路线都可以单独规划。可模拟复杂环境中的交通流。</p><p> sumo中一个路网文件，分为路网net文件和交通需求（路径）route文件。net文件由node文件和edge文件组成。其中node表示节点，如一个交叉口。</p><p><img src="https://img-blog.csdnimg.cn/60231f738f734b0398c3be55f291d444.png" alt="在这里插入图片描述"></p><blockquote><ol><li>节点文件 node file (.nod.xml)</li><li>连边文件 edge file (.edg.xml)</li><li>类型文件 edge type file (.type.xml)</li><li>基于上述三个文件创建路网文件 net file (.net.xml)</li><li>路由文件 route file (.rou.xml)</li></ol></blockquote><p>上述文件本质上都是 xml 文件，不过为了方便区分其作用，额外增加了一个后缀名。</p><p>假设我们要创建如下图所示的小型道路网络：</p><p><img src="https://img-blog.csdnimg.cn/61d22eb0edea46178083dc319f043d56.png" alt="在这里插入图片描述"></p><p>图中黑色节点对应交通路口，连边对应道路。每个路口所在位置坐标已给出。</p><p>node file</p><figure class="highlight xml"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="tag">&lt;<span class="name">nodes</span>&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">node</span> <span class="attr">id</span>=<span class="string">&quot;n1&quot;</span> <span class="attr">x</span>=<span class="string">&quot;-500&quot;</span> <span class="attr">y</span>=<span class="string">&quot;0&quot;</span> <span class="attr">type</span>=<span class="string">&quot;priority&quot;</span>/&gt;</span>   </span><br><span class="line"> <span class="tag">&lt;<span class="name">node</span> <span class="attr">id</span>=<span class="string">&quot;n2&quot;</span> <span class="attr">x</span>=<span class="string">&quot;-250&quot;</span> <span class="attr">y</span>=<span class="string">&quot;0&quot;</span> <span class="attr">type</span>=<span class="string">&quot;traffic_light&quot;</span>/&gt;</span></span><br><span 
class="line"> <span class="tag">&lt;<span class="name">node</span> <span class="attr">id</span>=<span class="string">&quot;n3&quot;</span> <span class="attr">x</span>=<span class="string">&quot;-150&quot;</span> <span class="attr">y</span>=<span class="string">&quot;200&quot;</span> <span class="attr">type</span>=<span class="string">&quot;traffic_light&quot;</span>/&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">node</span> <span class="attr">id</span>=<span class="string">&quot;n4&quot;</span> <span class="attr">x</span>=<span class="string">&quot;0&quot;</span> <span class="attr">y</span>=<span class="string">&quot;0&quot;</span>/&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">node</span> <span class="attr">id</span>=<span class="string">&quot;n5&quot;</span> <span class="attr">x</span>=<span class="string">&quot;150&quot;</span> <span class="attr">y</span>=<span class="string">&quot;200&quot;</span>/&gt;</span></span><br><span class="line"><span class="tag">&lt;/<span class="name">nodes</span>&gt;</span></span><br></pre></td></tr></table></figure><p>edge file</p><figure class="highlight xml"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="tag">&lt;<span class="name">edges</span>&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">edge</span> <span class="attr">from</span>=<span class="string">&quot;n1&quot;</span> <span class="attr">to</span>=<span class="string">&quot;n2&quot;</span> <span class="attr">id</span>=<span class="string">&quot;1to2&quot;</span> <span class="attr">type</span>=<span class="string">&quot;3L45&quot;</span>/&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">edge</span> 
<span class="attr">from</span>=<span class="string">&quot;n2&quot;</span> <span class="attr">to</span>=<span class="string">&quot;n3&quot;</span> <span class="attr">id</span>=<span class="string">&quot;2to3&quot;</span> <span class="attr">type</span>=<span class="string">&quot;2L15&quot;</span>/&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">edge</span> <span class="attr">from</span>=<span class="string">&quot;n3&quot;</span> <span class="attr">to</span>=<span class="string">&quot;n4&quot;</span> <span class="attr">id</span>=<span class="string">&quot;3to4&quot;</span> <span class="attr">type</span>=<span class="string">&quot;3L30&quot;</span>/&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">edge</span> <span class="attr">from</span>=<span class="string">&quot;n4&quot;</span> <span class="attr">to</span>=<span class="string">&quot;n5&quot;</span> <span class="attr">id</span>=<span class="string">&quot;out&quot;</span> <span class="attr">type</span>=<span class="string">&quot;3L30&quot;</span>/&gt;</span></span><br><span class="line"><span class="tag">&lt;/<span class="name">edges</span>&gt;</span></span><br></pre></td></tr></table></figure><p>type file</p><figure class="highlight xml"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="tag">&lt;<span class="name">types</span>&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">type</span> <span class="attr">id</span>=<span class="string">&quot;3L45&quot;</span> <span class="attr">priority</span>=<span class="string">&quot;3&quot;</span> <span class="attr">numLanes</span>=<span class="string">&quot;3&quot;</span> <span class="attr">speed</span>=<span class="string">&quot;45&quot;</span>/&gt;</span></span><br><span 
class="line"> <span class="tag">&lt;<span class="name">type</span> <span class="attr">id</span>=<span class="string">&quot;2L15&quot;</span> <span class="attr">priority</span>=<span class="string">&quot;3&quot;</span> <span class="attr">numLanes</span>=<span class="string">&quot;2&quot;</span> <span class="attr">speed</span>=<span class="string">&quot;15&quot;</span>/&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">type</span> <span class="attr">id</span>=<span class="string">&quot;3L30&quot;</span> <span class="attr">priority</span>=<span class="string">&quot;2&quot;</span> <span class="attr">numLanes</span>=<span class="string">&quot;3&quot;</span> <span class="attr">speed</span>=<span class="string">&quot;30&quot;</span>/&gt;</span></span><br><span class="line"><span class="tag">&lt;/<span class="name">types</span>&gt;</span></span><br></pre></td></tr></table></figure><p>基于以上三个文件，可以通过命令 netconvert 创建 net 文件，命令如下：</p><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">netconvert --node-files my_nodes.nod.xml --edge-files my_edge.edg.xml -t my_type.type.xml -o my_net.net.xml</span><br></pre></td></tr></table></figure><p>route file</p><figure class="highlight xml"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="tag">&lt;<span class="name">routes</span>&gt;</span></span><br><span class="line">   <span class="tag">&lt;<span class="name">route</span> <span class="attr">id</span>=<span 
class="string">&quot;route0&quot;</span> <span class="attr">edges</span>=<span class="string">&quot;1to2 2to3&quot;</span>/&gt;</span>  # edges 中的基本格式为&quot;edge1 edge2 edge3 ...&quot;</span><br><span class="line">   <span class="tag">&lt;<span class="name">route</span> <span class="attr">id</span>=<span class="string">&quot;route1&quot;</span> <span class="attr">edges</span>=<span class="string">&quot;2to3 3to4&quot;</span>/&gt;</span></span><br><span class="line">   <span class="tag">&lt;<span class="name">route</span> <span class="attr">id</span>=<span class="string">&quot;route2&quot;</span> <span class="attr">edges</span>=<span class="string">&quot;3to4 out&quot;</span>/&gt;</span></span><br><span class="line"></span><br><span class="line">   <span class="tag">&lt;<span class="name">vType</span> <span class="attr">accel</span>=<span class="string">&quot;1.0&quot;</span> <span class="attr">decel</span>=<span class="string">&quot;5.0&quot;</span> <span class="attr">id</span>=<span class="string">&quot;Car&quot;</span> <span class="attr">length</span>=<span class="string">&quot;2.0&quot;</span> <span class="attr">maxSpeed</span>=<span class="string">&quot;100.0&quot;</span> <span class="attr">sigma</span>=<span class="string">&quot;0.0&quot;</span>/&gt;</span></span><br><span class="line">   <span class="tag">&lt;<span class="name">vType</span> <span class="attr">accel</span>=<span class="string">&quot;1.0&quot;</span> <span class="attr">decel</span>=<span class="string">&quot;5.0&quot;</span> <span class="attr">id</span>=<span class="string">&quot;Bus&quot;</span> <span class="attr">length</span>=<span class="string">&quot;12.0&quot;</span> <span class="attr">maxSpeed</span>=<span class="string">&quot;1.0&quot;</span> <span class="attr">sigma</span>=<span class="string">&quot;0.0&quot;</span>/&gt;</span> #sigma随机程度，0 为无随机</span><br><span class="line"></span><br><span class="line">   <span class="tag">&lt;<span class="name">vehicle</span> <span 
class="attr">id</span>=<span class="string">&quot;veh0&quot;</span> <span class="attr">depart</span>=<span class="string">&quot;10&quot;</span> <span class="attr">route</span>=<span class="string">&quot;route0&quot;</span> <span class="attr">type</span>=<span class="string">&quot;Bus&quot;</span>/&gt;</span></span><br><span class="line">   <span class="tag">&lt;<span class="name">vehicle</span> <span class="attr">id</span>=<span class="string">&quot;veh1&quot;</span> <span class="attr">depart</span>=<span class="string">&quot;10&quot;</span> <span class="attr">route</span>=<span class="string">&quot;route1&quot;</span> <span class="attr">type</span>=<span class="string">&quot;Car&quot;</span>/&gt;</span></span><br><span class="line">   <span class="tag">&lt;<span class="name">vehicle</span> <span class="attr">id</span>=<span class="string">&quot;veh2&quot;</span> <span class="attr">depart</span>=<span class="string">&quot;30&quot;</span> <span class="attr">route</span>=<span class="string">&quot;route2&quot;</span> <span class="attr">type</span>=<span class="string">&quot;Car&quot;</span>/&gt;</span></span><br><span class="line"><span class="tag">&lt;/<span class="name">routes</span>&gt;</span></span><br></pre></td></tr></table></figure><p>运行程序时需要送入一些参数，可以通过命令行形式送入，如果参数太多、太长，为了方便起见，可以将参数统一放到 xml config 文件中，在运行时，可以调用这个 config 文件。</p><p>定义 my_config_file.sumocfg</p><figure class="highlight xml"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="tag">&lt;<span class="name">configuration</span>&gt;</span></span><br><span class="line">   <span class="tag">&lt;<span class="name">input</span>&gt;</span></span><br><span 
class="line">     <span class="tag">&lt;<span class="name">net-file</span> <span class="attr">value</span>=<span class="string">&quot;my_net.net.xml&quot;</span>/&gt;</span></span><br><span class="line">     <span class="tag">&lt;<span class="name">route-files</span> <span class="attr">value</span>=<span class="string">&quot;my_route.rou.xml&quot;</span>/&gt;</span></span><br><span class="line">   <span class="tag">&lt;/<span class="name">input</span>&gt;</span></span><br><span class="line">   <span class="tag">&lt;<span class="name">time</span>&gt;</span></span><br><span class="line">     <span class="tag">&lt;<span class="name">begin</span> <span class="attr">value</span>=<span class="string">&quot;0&quot;</span>/&gt;</span></span><br><span class="line">     <span class="tag">&lt;<span class="name">end</span> <span class="attr">value</span>=<span class="string">&quot;2000&quot;</span>/&gt;</span></span><br><span class="line">   <span class="tag">&lt;/<span class="name">time</span>&gt;</span></span><br><span class="line"><span class="tag">&lt;/<span class="name">configuration</span>&gt;</span></span><br></pre></td></tr></table></figure><p>如果一个参数既出现在了 config 文件中，又在 command line 中，则采用 command line 的设置。</p><p>一切准备就绪，下边运行程序</p><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">sumo-gui my_config_file.sumocfg</span><br></pre></td></tr></table></figure><p>然后将工具栏中的 Delay 设置为 100 ms，否则仿真开始之后瞬间结束。</p><p><img src="https://img-blog.csdnimg.cn/dc11e9a366874a8697593be2a26dfd83.png" alt="在这里插入图片描述"></p><p>在手动构造路网 net.xml 文件时，我们也可以用 SUMO 自带的 NETEDIT 程序，<strong>通过 NETEDIT GUI 编辑路网</strong>，可能效率更高一些</p><p>上述手动设置路网的方式只适用于比较简单的情况，如果要构造与现实世界比较接近的大型路网，我们可以用下边的从外部导入 OSM （Open Street Map）路网的方法。通过搜索城市、街道找到目标道路网，然后 export 即可。</p><p><img src="https://img-blog.csdnimg.cn/faea5665c9234a408fc94c2affff19ed.png" alt="在这里插入图片描述"></p><p>转化成 SUMO 路网文件</p><figure class="highlight bash"><table><tr><td 
class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">netconvert --osm-files map.osm -o sjtu.net.xml</span><br></pre></td></tr></table></figure><p>This produces the .net.xml file. Here it is not assembled from node, edge and type files; it is converted directly from the OSM map. Next, we need a route file.</p><p>For a network this large, writing the route file by hand is tedious as well. Instead, we use randomTrips.py, which ships with SUMO in the tools directory under $SUMO_HOME, to generate a <strong>random route file</strong>.</p><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">python &lt;path_to_randomTrips.py&gt; -n sjtu.net.xml -r sjtu.rou.xml -e 50 -l  # -e: end time; -l: weight edge choice by length</span><br></pre></td></tr></table></figure><p>Finally, combine everything in sjtu.sumocfg:</p><figure class="highlight xml"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="tag">&lt;<span class="name">configuration</span>&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">input</span>&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">net-file</span> <span class="attr">value</span>=<span class="string">&quot;sjtu.net.xml&quot;</span>/&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">route-files</span> <span class="attr">value</span>=<span class="string">&quot;sjtu.rou.xml&quot;</span>/&gt;</span></span><br><span class="line"> <span class="tag">&lt;/<span class="name">input</span>&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">time</span>&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">begin</span> <span 
class="attr">value</span>=<span class="string">&quot;0&quot;</span>/&gt;</span></span><br><span class="line"> <span class="tag">&lt;<span class="name">end</span> <span class="attr">value</span>=<span class="string">&quot;2000&quot;</span>/&gt;</span></span><br><span class="line"> <span class="tag">&lt;/<span class="name">time</span>&gt;</span></span><br><span class="line"><span class="tag">&lt;/<span class="name">configuration</span>&gt;</span></span><br></pre></td></tr></table></figure><p>Run the simulation (zoomed-in view):</p><p><img src="https://img-blog.csdnimg.cn/d6689ba707cb4331b3662f9822cd77df.png" alt="在这里插入图片描述"></p><p>Importing an OSM map this way is still fairly cumbersome. It involves four steps:</p><ul><li>Download the OSM map from the OSM website</li><li>Convert the OSM map into SUMO's .net.xml format with netconvert</li><li>Generate a random route file with randomTrips.py</li><li>Start the simulation</li></ul><p>In fact, SUMO ships with <strong>osmWebWizard</strong>.py, which combines these separate steps in a single interface so they can be done in one go.</p><p>Running a simulation with osmWebWizard.py (located in the tools directory under $SUMO_HOME) is also the first project in the SUMO tutorials.</p><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">python osmWebWizard.py</span><br></pre></td></tr></table></figure><p>If everything works, the following page opens in the browser. The initial map location is Berlin.</p><p><img src="https://img-blog.csdnimg.cn/0c221a9f1cf040468abda439d772f74c.png" alt="在这里插入图片描述"></p><p>First, choose the map area to simulate. You can zoom and pan the view, and use Select Area on the right to select a region. Avoid selecting a very large area. Otherwise the simulation consumes a lot of resources and may even freeze the machine.</p><p>Once the map and routes are configured, click Generate Scenario at the top right to launch the simulation.</p><h3 id="安装"><a href="#安装" class="headerlink" title="安装"></a>Installation</h3><p>On macOS, install XQuartz first. sumo-gui and netedit need it to start.</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span 
class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line">brew install --cask xquartz </span><br><span class="line"># Install SUMO</span><br><span class="line">brew tap dlr-ts/sumo</span><br><span class="line">brew install sumo</span><br><span class="line"># Set the environment variable (edit ~/.zshrc instead if your shell is zsh, the macOS default)</span><br><span class="line">touch ~/.bashrc; open ~/.bashrc</span><br><span class="line"># Append this as the last line; the install path is printed in the terminal after installation.</span><br><span class="line">export SUMO_HOME=/your/path/to/sumo</span><br><span class="line"># Check the variable: restart the terminal and run</span><br><span class="line">echo $SUMO_HOME</span><br><span class="line"># Install the macOS application bundles</span><br><span class="line">brew install --cask sumo-gui</span><br><span class="line"># Download the SUMO launchers from the download page</span><br><span class="line"># Launch XQuartz or sumo-gui from the terminal</span><br><span class="line"></span><br></pre></td></tr></table></figure><h3 id="Traci接口"><a href="#Traci接口" class="headerlink" title="Traci接口"></a>TraCI interface</h3><p>TraCI (Traffic Control Interface) is used to communicate with the SUMO simulator. Clicking through the sumo-gui interface is not always practical, so programs written in languages such as Python or Java talk to SUMO through the TraCI interface instead.</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span 
class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> os</span><br><span class="line"><span class="keyword">import</span> sys</span><br><span class="line"><span class="keyword">import</span> traci</span><br><span class="line"><span class="keyword">import</span> random</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">init_sumo</span>(<span class="params">sumoBinary, sumocfg</span>):</span><br><span class="line">    <span class="keyword">if</span> <span 
class="string">&#x27;SUMO_HOME&#x27;</span> <span class="keyword">in</span> os.environ:</span><br><span class="line">        tools = os.path.join(os.environ[<span class="string">&#x27;SUMO_HOME&#x27;</span>], <span class="string">&#x27;tools&#x27;</span>)</span><br><span class="line">        sys.path.append(tools)</span><br><span class="line">        sumoCmd = [sumoBinary, <span class="string">&quot;-c&quot;</span>, sumocfg, <span class="string">&quot;--tripinfo-output&quot;</span>, <span class="string">&quot;tripinfo.xml&quot;</span>]</span><br><span class="line">        <span class="keyword">return</span> sumoCmd</span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">        sys.exit(<span class="string">&quot;please declare environment variable &#x27;SUMO_HOME&#x27;&quot;</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">generate_routefile</span>():</span><br><span class="line">    random.seed(<span class="number">42</span>)  <span class="comment"># make tests reproducible</span></span><br><span class="line">    N = <span class="number">3600</span>  <span class="comment"># number of time steps</span></span><br><span class="line">    <span class="comment"># demand per second from different directions</span></span><br><span class="line">    pWE = <span class="number">1.</span> / <span class="number">10</span></span><br><span class="line">    pEW = <span class="number">1.</span> / <span class="number">11</span></span><br><span class="line">    pNS = <span class="number">1.</span> / <span class="number">30</span></span><br><span class="line">    <span class="keyword">with</span> <span class="built_in">open</span>(<span class="string">&quot;data/cross.rou.xml&quot;</span>, <span class="string">&quot;w&quot;</span>) <span class="keyword">as</span> routes:</span><br><span class="line">        <span 
class="built_in">print</span>(<span class="string">&quot;&quot;&quot;&lt;routes&gt;</span></span><br><span class="line"><span class="string">        &lt;vType id=&quot;typeWE&quot; accel=&quot;0.8&quot; decel=&quot;4.5&quot; sigma=&quot;0.5&quot; length=&quot;5&quot; minGap=&quot;2.5&quot; maxSpeed=&quot;16.67&quot; \</span></span><br><span class="line"><span class="string">guiShape=&quot;passenger&quot;/&gt;</span></span><br><span class="line"><span class="string">        &lt;vType id=&quot;typeNS&quot; accel=&quot;0.8&quot; decel=&quot;4.5&quot; sigma=&quot;0.5&quot; length=&quot;7&quot; minGap=&quot;3&quot; maxSpeed=&quot;25&quot; guiShape=&quot;bus&quot;/&gt;</span></span><br><span class="line"><span class="string"></span></span><br><span class="line"><span class="string">        &lt;route id=&quot;right&quot; edges=&quot;51o 1i 2o 52i&quot; /&gt;</span></span><br><span class="line"><span class="string">        &lt;route id=&quot;left&quot; edges=&quot;52o 2i 1o 51i&quot; /&gt;</span></span><br><span class="line"><span class="string">        &lt;route id=&quot;down&quot; edges=&quot;54o 4i 3o 53i&quot; /&gt;&quot;&quot;&quot;</span>, file=routes)</span><br><span class="line">        vehNr = <span class="number">0</span></span><br><span class="line">        <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(N):</span><br><span class="line">            <span class="keyword">if</span> random.uniform(<span class="number">0</span>, <span class="number">1</span>) &lt; pWE:</span><br><span class="line">                <span class="built_in">print</span>(<span class="string">&#x27;    &lt;vehicle id=&quot;right_%i&quot; type=&quot;typeWE&quot; route=&quot;right&quot; depart=&quot;%i&quot; /&gt;&#x27;</span> % (</span><br><span class="line">                    vehNr, i), file=routes)</span><br><span class="line">                vehNr += <span class="number">1</span></span><br><span class="line">            <span 
class="keyword">if</span> random.uniform(<span class="number">0</span>, <span class="number">1</span>) &lt; pEW:</span><br><span class="line">                <span class="built_in">print</span>(<span class="string">&#x27;    &lt;vehicle id=&quot;left_%i&quot; type=&quot;typeWE&quot; route=&quot;left&quot; depart=&quot;%i&quot; /&gt;&#x27;</span> % (</span><br><span class="line">                    vehNr, i), file=routes)</span><br><span class="line">                vehNr += <span class="number">1</span></span><br><span class="line">            <span class="keyword">if</span> random.uniform(<span class="number">0</span>, <span class="number">1</span>) &lt; pNS:</span><br><span class="line">                <span class="built_in">print</span>(<span class="string">&#x27;    &lt;vehicle id=&quot;down_%i&quot; type=&quot;typeNS&quot; route=&quot;down&quot; depart=&quot;%i&quot; color=&quot;1,0,0&quot;/&gt;&#x27;</span> % (</span><br><span class="line">                    vehNr, i), file=routes)</span><br><span class="line">                vehNr += <span class="number">1</span></span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&quot;&lt;/routes&gt;&quot;</span>, file=routes)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">if</span> __name__ == <span class="string">&#x27;__main__&#x27;</span>:</span><br><span class="line">    sumoCmd = init_sumo(<span class="string">&quot;sumo-gui&quot;</span>, os.getcwd()+<span class="string">&quot;/data/cross.sumocfg&quot;</span>)</span><br><span class="line">    generate_routefile()</span><br><span class="line"></span><br><span class="line">    traci.start(sumoCmd)</span><br><span class="line">    step = <span class="number">0</span></span><br><span class="line">    <span class="comment"># we start with phase 2 where EW has green</span></span><br><span class="line">    traci.trafficlight.setPhase(<span 
class="string">&quot;0&quot;</span>, <span class="number">2</span>)</span><br><span class="line">    <span class="keyword">while</span> traci.simulation.getMinExpectedNumber() &gt; <span class="number">0</span>:</span><br><span class="line">        traci.simulationStep()</span><br><span class="line">        <span class="keyword">if</span> traci.trafficlight.getPhase(<span class="string">&quot;0&quot;</span>) == <span class="number">2</span>:</span><br><span class="line">            <span class="comment"># we are not already switching</span></span><br><span class="line">            <span class="keyword">if</span> traci.inductionloop.getLastStepVehicleNumber(<span class="string">&quot;0&quot;</span>) &gt; <span class="number">0</span>:</span><br><span class="line">                <span class="comment"># there is a vehicle from the north, switch</span></span><br><span class="line">                traci.trafficlight.setPhase(<span class="string">&quot;0&quot;</span>, <span class="number">3</span>)</span><br><span class="line">            <span class="keyword">else</span>:</span><br><span class="line">                <span class="comment"># otherwise try to keep green for EW</span></span><br><span class="line">                traci.trafficlight.setPhase(<span class="string">&quot;0&quot;</span>, <span class="number">2</span>)</span><br><span class="line">        step += <span class="number">1</span></span><br><span class="line">    traci.close()</span><br></pre></td></tr></table></figure><p>Notes on the functions involved:</p><blockquote><p>traci.trafficlight.setPhase(tlsID, index) switches a traffic light to a phase of its current program. The first argument is the traffic light ID, the second the phase index. In this example, phase 2 gives green to the east-west direction.</p><p>traci.trafficlight.getNextSwitch(tlsID) returns the absolute simulation time (s) at which the next phase switch is scheduled.</p><p>traci.trafficlight.getPhaseDuration(tlsID) returns the total duration (s) of the current phase, not the time already elapsed.</p><p>traci.trafficlight.setPhaseDuration(tlsID, double) sets the remaining duration (s) of the current phase.</p><p>traci.simulation.getMinExpectedNumber() returns the number of vehicles currently in the network plus those still waiting to depart. When it reaches 0, all vehicles have left the network and the simulation can stop.</p><p>traci.inductionloop.getLastStepMeanSpeed(loopID) -&gt; double returns the mean speed (m/s) of the vehicles on the induction loop in the last simulation step.</p><p>traci.inductionloop.getLastStepVehicleIDs(loopID) -&gt; list(string) returns the IDs of the vehicles that passed the induction loop in the last simulation step. 
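</p><p>traci.inductionloop.getLastStepVehicleNumber(loopID) -&gt; integer returns the number of vehicles that passed the given induction loop in the last simulation step. This is useful for checking whether any vehicle is arriving from a given direction.</p><p>traci.edge.getLastStepVehicleNumber(edgeID) -&gt; integer returns the number of vehicles on the given edge in the last simulation step.</p></blockquote><p>Functions in the traci.vehicle domain for querying and controlling vehicles:</p><blockquote><p>changeTarget(string, string) -&gt; None changes the vehicle's destination edge and reroutes it.</p><p>getAccel(string) -&gt; double returns the vehicle's maximum acceleration capability (m/s²). The acceleration actually applied in the last step is returned by getAcceleration.</p><p>getPosition(string) -&gt; (double, double) returns the vehicle's x, y position in the network.</p><p>isStopped(string) -&gt; bool checks whether the vehicle is stopped.</p><p>setAccel(string, double) -&gt; None sets the vehicle's maximum acceleration.</p><p>setMaxSpeed(string, double) -&gt; None sets the vehicle's maximum speed.</p><p>setStop(string, string, double, integer, double, integer, double, double) -&gt; None adds a stop for the vehicle (vehicle ID, edge ID, position, lane index, duration, flags, start position, until).</p></blockquote><p>Pedestrian functions:</p><blockquote><p>traci.edge.getLastStepPersonIDs(edge) returns the IDs of the persons on the edge in the last simulation step.</p><p>traci.person.getWaitingTime(ped) returns the waiting time (s) of the given person.</p></blockquote><p>Subscriptions: a subscription can be seen as a <strong>batch mode for retrieving variables</strong>. <strong>Instead of requesting the same variables over and over</strong>, you subscribe once and the values of interest are retrieved automatically after every time step.</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> traci</span><br><span class="line"><span class="keyword">import</span> traci.constants <span class="keyword">as</span> tc</span><br><span class="line">vehID = <span class="string">&quot;veh0&quot;</span>  <span class="comment"># example ID of a vehicle that is already in the network</span></span><br><span class="line">PORT = <span class="number">8813</span>  <span class="comment"># SUMO must have been started with --remote-port 8813</span></span><br><span class="line">traci.init(PORT) </span><br><span class="line">traci.vehicle.subscribe(vehID, (tc.VAR_ROAD_ID, tc.VAR_LANEPOSITION))</span><br><span class="line"><span class="built_in">print</span>(traci.vehicle.getSubscriptionResults(vehID))</span><br><span class="line"><span class="keyword">for</span> step 
<span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">3</span>):</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;step&quot;</span>, step)</span><br><span class="line">    traci.simulationStep()</span><br><span class="line">    <span class="built_in">print</span>(traci.vehicle.getSubscriptionResults(vehID))</span><br><span class="line">traci.close()</span><br></pre></td></tr></table></figure><p>The retrieved values always come from the last time step. Older values cannot be retrieved.</p><h3 id="动力学模型"><a href="#动力学模型" class="headerlink" title="动力学模型"></a>Vehicle dynamics models</h3><p>Vehicle dynamics in SUMO cover two aspects:</p><blockquote><p><strong>longitudinal model</strong>: longitudinal dynamics, describing <strong>acceleration and deceleration</strong></p><p><strong>lateral model</strong>: lateral dynamics, describing <strong>lane changes</strong></p></blockquote><p>On the longitudinal side, SUMO is mainly used to study vehicles' external behaviour, interactions between vehicles and traffic flow. Individual vehicles therefore do not need to be modelled precisely and can be approximated as <strong>point masses</strong>. A relatively simple car-following model describes how a vehicle's speed and position evolve. It distinguishes two cases: with a leading vehicle and without one.</p><p>Without a leader, the vehicle drives at its maximum speed. This maximum must take at least three factors into account; it is the minimum of the following three speeds:</p><ul><li>The maximum physical speed this vehicle type can reach</li><li>The speed reachable in the current step by accelerating at the maximum rate from the previous step's speed</li><li>The speed limit of the road currently being driven on</li></ul><p>With a leader, the vehicle must compute a safe speed that guarantees it will not collide in any situation, in particular when the leader brakes hard. Car-following models differ mainly in <strong>how they compute this safe speed</strong>. SUMO currently uses a modified Krauss model.</p><p>On the lateral side, SUMO uses a lane-changing model which, put simply, <strong>encodes many lane-change conditions as a decision tree</strong>: whenever certain conditions are met, the corresponding lane change is performed. The default lane-changing model changes lanes <strong>instantaneously</strong>, i.e. within a single simulation step, so visually the vehicle <strong>jumps between the two lanes</strong>. More detailed options include the sublane model and the Simple Continuous lane-change model.</p><h3 id="Krauss-model"><a href="#Krauss-model" class="headerlink" title="Krauss model"></a>Krauss model</h3><p>First, the modelling idea behind the original Krauss model:</p><p><img src="https://img-blog.csdnimg.cn/6c87e4f80197429ea493d925a39dcd4a.png" alt="在这里插入图片描述"></p><p>Approximating with a Taylor expansion gives the following estimate:</p><p><img src="https://img-blog.csdnimg.cn/89725e553fea45049f59a6ab06d39767.png" alt="在这里插入图片描述"></p><figure class="highlight cpp"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span 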
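class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">double MSCFModel_KraussOrig1::vsafe(double gap, double predSpeed, double /* predMaxDecel */) const &#123;</span><br><span class="line">    ...</span><br><span class="line">     double vsafe = (double)(-<span class="number">1.</span> * myTauDecel + sqrt( myTauDecel * myTauDecel + (predSpeed * predSpeed) + (<span class="number">2.</span> * myDecel * gap) ));</span><br><span class="line">     <span class="keyword">assert</span>(vsafe &gt;= <span class="number">0</span>);</span><br><span class="line">     <span class="keyword">return</span> vsafe;</span><br><span class="line"> &#125;</span><br></pre></td></tr></table></figure><p>This is not yet the speed the vehicle actually uses when following. As in the case without a leader, the following speed must not exceed the maximum allowed speed, so the vehicle takes the smaller of the safe speed and the maximum allowed speed.</p><p>The modified model starts from the same principle as the original Krauss model: drive as fast as possible while guaranteeing that no collision occurs. How it computes the safe speed, however, is completely different from the original Krauss model.</p><p><strong>It does not approximate the braking-distance function with a Taylor expansion</strong>; instead, it computes it numerically. The idea is to find the safe following speed at which the follower's braking distance (including the distance covered during its reaction time) exactly equals the leader's braking distance plus the current gap between the two vehicles.</p><p><img src="https://img-blog.csdnimg.cn/e6a5f623b3a84090b0ed76a423b502b8.png" alt="在这里插入图片描述"></p><h3 id="lane-changing-model"><a href="#lane-changing-model" class="headerlink" title="lane changing model"></a>lane changing model</h3><p>The microscopic driving dynamics of road vehicles are determined by the interplay of several models:</p><ul><li>Car-following model: determines the vehicle's own speed based on the behaviour of the leader.</li><li>Intersection model: determines how vehicles behave at different types of intersections, covering right-of-way rules, gap acceptance, avoiding blocking the junction, etc.</li><li><strong>Lane-changing model</strong>: determines lane choice on multi-lane roads and the speed adjustments made when changing lanes.</li></ul><p>Compared with other microscopic lane-changing models, this model explicitly distinguishes four lane-change motivations:</p><ul><li>Strategic change: the vehicle <strong>must change lanes</strong> in order to reach the next edge on its <strong>route</strong>.</li><li>Cooperative change: changing lanes to make room so that another vehicle can change into this vehicle's lane.</li><li>Tactical change: the vehicle tries to avoid following a slow leader, weighing the expected speed gain from changing lanes against the effort of changing.</li><li>Obligatory change: a forced move to clear the overtaking lane is treated as an obligatory change.</li></ul><p> 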
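Lane-change planning consists of four sub-steps:</p><ul><li>Compute the preferred successor lanes;</li><li>Compute the safe speed under the assumption that the vehicle stays in its current lane, incorporating lane-change-related speed requests from previous simulation steps;</li><li>The lane-changing model computes a change request (left, right or stay);</li><li>Execute the lane change, or compute speed requests for the next simulation step (possibly planning several steps ahead). Whether a speed change is requested depends on the urgency of the lane-change request.</li></ul><p>Criteria for evaluating the candidate lanes:</p><ul><li>bestLanes (lanes that can be followed without changing lanes)</li><li>occupation (vehicle density along the best lanes)</li><li>bestLaneOffset (offset from the current lane to the best lane)</li></ul><p>Assessing the urgency of a lane change:</p><p><img src="https://img-blog.csdnimg.cn/89d77c72255c4c229a0eafe69e3abff9.png" alt="在这里插入图片描述"> </p><p>The vehicle examines its relationship with the blocking vehicle and adapts its behaviour accordingly:</p><blockquote><p>Whenever a desired lane change cannot be performed because of a blocking vehicle, the vehicle can adjust its speed so that the lane change succeeds in a later step. In addition, the vehicle may influence the speed of the blocking vehicle (in reality, this usually happens as a reaction to seeing the ego vehicle's turn signal).</p></blockquote><p>Avoiding deadlock: if, for some reason, two vehicles reach the end of their lanes at the same time and each wants to change into the other's lane, neither can proceed. This situation is a deadlock.</p><blockquote><p>To avoid this, the two vehicles are distinguished: the one closer to the end of the road is the <code>blocking leader</code>, the other the <code>blocking follower</code>. The follower slows down in advance to leave the leader enough space to change lanes. Even so, deadlocks cannot always be avoided, because several lanes may be involved. The model therefore <strong>reserves a stretch of 20 to 40 m for lane changing</strong>.</p></blockquote>]]></content>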
    
    <summary type="html">
    
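&lt;p&gt;SUMO (Simulation of Urban Mobility) is free, open-source &lt;strong&gt;traffic simulation software&lt;/strong&gt;. It simulates traffic at the microscopic level: the route of every single vehicle on the road can be planned individually, which makes it possible to simulate traffic flow in complex environments.&lt;/p&gt;
&lt;p&gt;A SUMO scenario consists of a road network (net) file and a traffic demand (route) file. The net file is built from a node file and an edge file, where a node represents a point such as an intersection.&lt;/p&gt;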
&lt;p&gt;&lt;img src=&quot;https://img-blog.csdnimg.cn/60231f738f734b0398c3be55f291d444.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;blockquote&gt;
&lt;ol&gt;
&lt;li&gt;Node file (.nod.xml)&lt;/li&gt;
&lt;li&gt;Edge file (.edg.xml)&lt;/li&gt;
&lt;li&gt;Edge type file (.type.xml)&lt;/li&gt;
&lt;li&gt;Network file (.net.xml), built from the three files above&lt;/li&gt;
&lt;li&gt;Route file (.rou.xml)&lt;/li&gt;
&lt;/ol&gt;
&lt;/blockquote&gt;
&lt;p&gt;All of these are ordinary XML files; the extra suffix only makes their roles easier to tell apart.&lt;/p&gt;
&lt;p&gt;Suppose we want to build the small road network shown below:&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://img-blog.csdnimg.cn/61d22eb0edea46178083dc319f043d56.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;p&gt;The black nodes in the figure are intersections and the edges are roads. The coordinates of each intersection are given.&lt;/p&gt;
&lt;p&gt;Node file:&lt;/p&gt;
&lt;figure class=&quot;highlight xml&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;tag&quot;&gt;&amp;lt;&lt;span class=&quot;name&quot;&gt;nodes&lt;/span&gt;&amp;gt;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;span class=&quot;tag&quot;&gt;&amp;lt;&lt;span class=&quot;name&quot;&gt;node&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;id&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;n1&amp;quot;&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;x&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;-500&amp;quot;&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;y&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;0&amp;quot;&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;type&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;priority&amp;quot;&lt;/span&gt;/&amp;gt;&lt;/span&gt;   &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;span class=&quot;tag&quot;&gt;&amp;lt;&lt;span class=&quot;name&quot;&gt;node&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;id&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;n2&amp;quot;&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;x&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;-250&amp;quot;&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;y&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;0&amp;quot;&lt;/span&gt; &lt;span 
class=&quot;attr&quot;&gt;type&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;traffic_light&amp;quot;&lt;/span&gt;/&amp;gt;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;span class=&quot;tag&quot;&gt;&amp;lt;&lt;span class=&quot;name&quot;&gt;node&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;id&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;n3&amp;quot;&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;x&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;-150&amp;quot;&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;y&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;200&amp;quot;&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;type&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;traffic_light&amp;quot;&lt;/span&gt;/&amp;gt;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;span class=&quot;tag&quot;&gt;&amp;lt;&lt;span class=&quot;name&quot;&gt;node&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;id&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;n4&amp;quot;&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;x&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;0&amp;quot;&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;y&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;0&amp;quot;&lt;/span&gt;/&amp;gt;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;span class=&quot;tag&quot;&gt;&amp;lt;&lt;span class=&quot;name&quot;&gt;node&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;id&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;n5&amp;quot;&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;x&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;150&amp;quot;&lt;/span&gt; &lt;span class=&quot;attr&quot;&gt;y&lt;/span&gt;=&lt;span class=&quot;string&quot;&gt;&amp;quot;200&amp;quot;&lt;/span&gt;/&amp;gt;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span 
class=&quot;tag&quot;&gt;&amp;lt;/&lt;span class=&quot;name&quot;&gt;nodes&lt;/span&gt;&amp;gt;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;
    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/1392291413.html"/>
    <id>https://blog.aivgg.com/posts/1392291413.html</id>
    <published>2026-06-10T16:18:36.418Z</published>
    <updated>2026-06-10T16:24:04.381Z</updated>
    
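    <content type="html"><![CDATA[<h3 id="Ray"><a href="#Ray" class="headerlink" title="Ray"></a>Ray</h3><p>Ray is an open-source project for parallel and distributed Python. When an application moves to a distributed setting, the familiar programming concepts change. Specialized tools such as TensorFlow for model training, Spark for data processing and SQL, and Flink for stream processing provide <strong>higher-level abstractions such as neural networks, datasets and streams</strong>. However, because these abstractions differ from those used in serial programming, <strong>applications must be rewritten</strong> to take advantage of them.</p><p><img src="https://img-blog.csdnimg.cn/682bfc588fe9442a8dab831a23e00800.png" alt="在这里插入图片描述"></p><p>Ray occupies a unique <strong>middle ground</strong>. Rather than introducing new concepts, Ray takes the existing concepts of functions and classes and carries them over to the distributed setting as tasks and actors. This <strong>API choice allows serial applications to be parallelized without major modifications</strong>.</p><p>Ray can scale Python applications across multiple cores or machines. Its main advantages include:</p><ul><li><p>Simplicity: you can scale your Python application without rewriting it, and the same code runs on one machine or many.</p></li><li><p>Robustness: applications handle machine failures and process preemption gracefully.</p></li><li><p>Performance: tasks run with millisecond-level latency, scale to tens of thousands of cores, and handle numerical data with minimal serialization overhead.</p></li></ul><p><img src="https://img-blog.csdnimg.cn/19c087b261e64effa0eab26efafaf4a1.png" alt="在这里插入图片描述"></p><p>As a distributed computing system, Ray still follows the typical master-slave design: the master handles global coordination and maintains state, while the slaves execute distributed tasks. Unlike traditional distributed computing systems, however, Ray uses <strong>hybrid task scheduling</strong>. The components below describe Ray's early architecture; later releases reorganized several of them.</p><ul><li><strong>GlobalScheduler</strong>: a global scheduler runs on the master. It receives tasks submitted by the local schedulers and dispatches them to suitable local schedulers for execution.</li><li><strong>RedisServer</strong>: one or more Redis servers run on the master and store the state of distributed tasks (ControlState), including the mapping from objects to machines, task descriptions and task debugging information.</li><li><strong>LocalScheduler</strong>: each slave runs a local scheduler that submits tasks to the global scheduler and assigns tasks to the worker processes on its machine.</li><li><strong>Worker</strong>: each slave can run several worker processes that execute distributed tasks and store the results in the ObjectStore.</li><li><strong>ObjectStore</strong>: each slave runs an ObjectStore that holds read-only data objects. Workers access these objects through shared memory, which reduces memory copies and serialization costs. The ObjectStore is built on Apache Arrow.</li><li><strong>Plasma</strong>: the ObjectStore on each slave is managed by an object manager called Plasma. When a worker needs a remote object that is not in the local ObjectStore, Plasma actively pulls it from another slave to the local machine.</li></ul><p>Like Spark, Ray submits work <strong>through a driver</strong>, with these differences:</p><ul><li>A Spark driver submits a DAG of tasks, which cannot be changed once submitted.</li><li>Ray submits finer-grained remote functions, and the task DAG is defined freely by the dependencies between function calls.</li></ul><h3 id="安装"><a href="#安装" class="headerlink" 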
title="安装"></a>Installation</h3><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">pip install --upgrade pip</span><br><span class="line">pip install ray==1.6.0</span><br></pre></td></tr></table></figure><h3 id="使用"><a href="#使用" class="headerlink" title="使用"></a>Usage</h3><p>1. ray.init(), comparable to creating a SparkSession</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> ray</span><br><span class="line">ray.init()</span><br></pre></td></tr></table></figure><p>To connect to an existing Ray cluster, only the address of its Redis server needs to be specified. In Ray 1.x the parameter is named <code>address</code>; early versions called it <code>redis_address</code>.</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">ray.init(address=<span class="string">&quot;&lt;redis-address&gt;&quot;</span>)</span><br></pre></td></tr></table></figure><p>When Ray is started locally, it prints the address of the Ray web UI (dashboard).</p><p>2. ray.put(), comparable to parallelizing a Spark RDD</p><p><code>ray.put()</code> stores a Python object in the local ObjectStore and returns a unique ObjectID (an ObjectRef in Ray 1.x). Through this ID, the object can be accessed from any node in the cluster.</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">@ray.remote</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">f</span>(<span class="params">x</span>):</span><br><span class="line">    <span 
class="keyword">pass</span></span><br><span class="line"> </span><br><span class="line">x = <span class="string">&quot;hello&quot;</span></span><br><span class="line"> </span><br><span class="line"><span class="comment"># x is copied into the ObjectStore 10 times</span></span><br><span class="line">[f.remote(x) <span class="keyword">for</span> _ <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">10</span>)]</span><br><span class="line"> </span><br><span class="line"><span class="comment"># x is copied into the ObjectStore only once</span></span><br><span class="line">x_id = ray.put(x)</span><br><span class="line">[f.remote(x_id) <span class="keyword">for</span> _ <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">10</span>)]</span><br></pre></td></tr></table></figure><p>3. ray.get()</p><p><code>ray.get()</code> fetches an object from the ObjectStore by its ObjectID and converts it into a Python object. For array-type objects, Ray uses shared memory to reduce copying costs; other objects must be copied from the ObjectStore into the process's heap memory.</p><p>If the object has not been created yet when <code>ray.get()</code> is called, the call blocks until the object has been created. The key steps of a get are:</p><ul><li>The driver or worker process first requests the data for the ObjectID from the ObjectStore.</li><li>If the local ObjectStore does not hold the object, the local object manager Plasma checks the object table on the master to see whether the object is stored in another node's ObjectStore.</li><li>If the object is in another node's ObjectStore, Plasma sends a network request to pull it into the local ObjectStore.</li><li>If the object has not been created yet, the master notifies the requesting Plasma once the object is created, so that Plasma can read it.</li><li>If the object has been evicted from every ObjectStore (by the LRU policy), the local scheduler recreates it by re-executing the tasks in its lineage.</li><li>Once the object is available in the local ObjectStore, the driver or worker process maps the object's memory region directly into its own address space through shared memory and deserializes it into a Python object.</li></ul><p><code>ray.get()</code> can also read several objects at once:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">result_ids = [ray.put(i) <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">10</span>)]</span><br><span class="line">ray.get(result_ids)  <span class="comment"># [0, 1, 2, 3, 4, 5, 6, 7, 8, 
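9]</span></span><br></pre></td></tr></table></figure><p>4. @ray.remote</p><p>In Ray, the <code>@ray.remote</code> decorator declares a remote function. Remote functions are Ray's basic unit of task scheduling. As soon as a remote function is defined, it is serialized, stored in the Redis server, and assigned a unique ID, which ensures that every node in the cluster can see its definition. This implies a requirement on remote functions: any other user function that a remote function calls <strong>must be defined beforehand</strong>; otherwise the remote function cannot find that function's definition.</p><p>The key steps of calling a remote function are:</p><ul><li>Calling a remote function first creates a task object that contains the function ID, the argument IDs or values (basic Python objects are passed by value, while complex objects are first stored in the ObjectStore with <code>ray.put()</code> and passed by ObjectID), and the IDs of the function's return values.</li><li>The task object is sent to the local scheduler.</li><li>The local scheduler decides whether to schedule the task locally or send it to the global scheduler. If the task's dependencies (arguments) are already in the local ObjectStore and there are enough local CPU and GPU resources, the local scheduler assigns the task to a local worker process. Otherwise, the task is sent to the global scheduler and stored in the task table (TaskTable), and the global scheduler uses the current task state to choose a local scheduler in the cluster to run it.</li><li>When a local scheduler receives a task (submitted locally or assigned by the global scheduler), it puts the task into a queue and hands it to a worker process once compute resources are available and its local dependencies are satisfied.</li><li>The worker executes the task, stores the function's return values in the ObjectStore, and updates the object table (ObjectTable) on the master.</li></ul><p>The <code>@ray.remote</code> decorator takes a <code>num_returns</code> parameter (named <code>num_return_vals</code> in older Ray versions) that declares how many values a remote function returns, which enables remote functions with multiple return values:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">@ray.remote(<span class="params">num_returns=<span class="number">2</span></span>)</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">f</span>():</span><br><span class="line">    <span class="keyword">return</span> <span class="number">1</span>, <span class="number">2</span></span><br><span class="line"> </span><br><span class="line">x_id, y_id = f.remote()</span><br><span class="line">ray.get(x_id)  <span class="comment"># 1</span></span><br><span class="line">ray.get(y_id)  <span class="comment"># 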
2</span></span><br></pre></td></tr></table></figure><p>Another <code>@ray.remote</code> parameter, <code>num_gpus</code>, specifies the GPU resources a task needs:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">@ray.remote(<span class="params">num_gpus=<span class="number">1</span></span>)</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">gpu_method</span>():</span><br><span class="line">    <span class="keyword">return</span> <span class="string">&quot;This function is allowed to use GPUs &#123;&#125;.&quot;</span>.<span class="built_in">format</span>(ray.get_gpu_ids())</span><br></pre></td></tr></table></figure><p>5. ray.wait()</p><p><code>ray.wait()</code> waits on a batch of tasks, which makes it possible to collect the results of several ObjectIDs as they become ready.</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># Launch 5 remote function calls</span></span><br><span class="line">results = [f.remote(i) <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">5</span>)]</span><br><span class="line"><span class="comment"># Block until 4 tasks have finished, with a 2.5 s timeout</span></span><br><span class="line">ready_ids, remaining_ids = ray.wait(results, num_returns=<span class="number">4</span>, timeout=<span 
class="number">2.5</span>)</span><br></pre></td></tr></table></figure><p>In this example, results holds 5 ObjectIDs. <code>ray.wait</code> waits until 4 of the tasks have finished, then returns the IDs of the finished objects in the first list and the IDs of the unfinished ones in the second list. If a timeout is set and the expected number of results is not ready when it expires, the call returns with whatever has finished by then.</p><p>6. ray.error_info()</p><p>ray.error_info() returns the error information produced while tasks were executing. In newer Ray versions, an exception raised inside a task is instead re-raised when <code>ray.get()</code> is called on the task's result.</p><p>7. Actor</p><p>Ray's remote functions only handle stateless computation; stateful computation is implemented with Ray actors. <strong>Putting <code>@ray.remote</code> in front of a Python class definition declares an actor</strong>.</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">@ray.remote</span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">Counter</span>(<span class="title class_ inherited__">object</span>):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="variable language_">self</span>.value = <span class="number">0</span></span><br><span class="line"> </span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">increment</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="variable language_">self</span>.value += <span class="number">1</span></span><br><span class="line">        <span class="keyword">return</span> <span class="variable language_">self</span>.value</span><br></pre></td></tr></table></figure><p>Actor instances are created as follows:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">a1 = Counter.remote()</span><br><span 
class="line">a2 = Counter.remote()</span><br></pre></td></tr></table></figure><p>Actor methods are called with <code>.remote()</code>:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">a1.increment.remote()  <span class="comment"># ray.get returns 1</span></span><br><span class="line">a2.increment.remote()  <span class="comment"># ray.get returns 1</span></span><br></pre></td></tr></table></figure><p>Calling an actor method works as follows:</p><ul><li>A task is created.</li><li>The driver assigns the task directly to the local scheduler of the node where the actor was created, bypassing the global scheduler. (Whether workers can also dispatch actor tasks directly in this way is still an open question.)</li><li>The ObjectID of the method's result is returned.</li></ul><p>To keep actor state consistent, <strong>method calls on the same actor are executed serially</strong>.</p><h3 id="RLlib"><a href="#RLlib" class="headerlink" title="RLlib"></a>RLlib</h3><p>RLlib is an open-source library for reinforcement learning that offers high scalability (Scalable Reinforcement Learning) and a unified API for a variety of applications. RLlib natively supports TensorFlow, TensorFlow Eager, and PyTorch, but most of its internals are framework-agnostic. <strong>RLlib is to Ray what MLlib is to Spark</strong>:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> ray <span class="keyword">import</span> tune</span><br><span class="line"><span class="keyword">from</span> ray.rllib.agents.ppo <span class="keyword">import</span> PPOTrainer</span><br><span class="line">tune.run(PPOTrainer, config=&#123;<span class="string">&quot;env&quot;</span>: <span class="string">&quot;CartPole-v0&quot;</span>&#125;)  <span class="comment"># &quot;log_level&quot;: &quot;INFO&quot; for verbose</span></span><br></pre></td></tr></table></figure><p>These three lines train an agent to play the CartPole balancing game.</p><p><img src="https://img-blog.csdnimg.cn/fe5fdeaa447246be8878b356277f4db3.png" alt="figure"></p><p>At the bottom, the distributed computation is powered by the Ray engine. The second layer from the bottom shows that RLlib provides abstractions for specific reinforcement learning tasks. The second layer from the top is aimed at developers, who can define custom algorithms. The top layer is RLlib's support for applications, for example letting agents interact with offline data, Gym, or Unity3D environments.</p><p><strong>Policies</strong>: policies are a core concept in RLlib. A policy is a Python class that defines how an agent acts in an environment. 
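Rollout workers query the policy to determine the agent's actions. In a <a href="https://www.cnblogs.com/itmorn/p/13760678.html#openai-gym">gym</a> environment there is a single agent and a single policy. In <a href="https://www.cnblogs.com/itmorn/p/13760678.html#vectorized">vector envs</a>, policy inference covers multiple agents at once, and in multi-agent settings there may be multiple policies, each controlling one or more agents:</p><p><img src="https://img-blog.csdnimg.cn/216496df4acc4833abc7f27b5e4512d4.png" alt="figure"></p><p><strong>Training</strong>: each policy defines a <code>learn_on_batch()</code> method that improves the policy given an input batch of samples. For TF and Torch policies, this is implemented with a loss function that takes sample batch tensors as input and outputs a scalar loss.</p><p>RLlib Trainer classes coordinate the distributed workflow of running rollout workers and optimizing policies. They use Ray parallel iterators to implement the desired computation patterns. The following figure shows synchronous sampling, the simplest of these patterns:</p><p><img src="https://img-blog.csdnimg.cn/ec5d8e6577b44e3eb0ab23200e610567.png" alt="figure"></p><p>The Trainer broadcasts the current policy weights to all workers; the workers interact with the environment to generate samples, which are sent back to the Trainer for training.</p><p>RLlib uses Ray actors to scale training from a single core to thousands of cores in a cluster. The parallelism used for training can be configured with the <code>num_workers</code> parameter.</p><p>RLlib provides ways to customize almost every aspect of training, including the environment, the neural network model, the action distribution, and the policy definitions:</p><p><img src="https://img-blog.csdnimg.cn/0809014f87c44290ba59a6e64d7da60d.png" alt="figure"></p><p>Tune, the hyperparameter search library:</p><p>Ray Tune is a Python library for experiment execution and hyperparameter tuning. It integrates search algorithms such as grid search, random search, and Bayesian optimization (BayesOptSearch), as well as optimization tools such as Optuna and Hyperopt. The models tuned with Ray Tune can be built with frameworks such as PyTorch, XGBoost, TensorFlow, or Keras.</p><p>Installation</p><figure class="highlight shell"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">pip install &#x27;ray[tune]&#x27;</span><br></pre></td></tr></table></figure><p>Using Tune to search for the best value of the learning rate <code>lr</code>:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span 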
class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch.optim <span class="keyword">as</span> optim</span><br><span class="line"><span class="keyword">from</span> ray <span class="keyword">import</span> tune</span><br><span class="line"><span class="keyword">from</span> ray.tune.examples.mnist_pytorch <span class="keyword">import</span> get_data_loaders, ConvNet, train, test</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">train_mnist</span>(<span class="params">config</span>):</span><br><span class="line">    train_loader, test_loader = get_data_loaders()</span><br><span class="line">    model = ConvNet()</span><br><span class="line">    optimizer = optim.SGD(model.parameters(), lr=config[<span class="string">&quot;lr&quot;</span>])</span><br><span class="line">    <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">30</span>):</span><br><span class="line">        train(model, optimizer, train_loader)</span><br><span class="line">        acc = test(model, test_loader)</span><br><span class="line">        tune.report(mean_accuracy=acc)  <span class="comment"># added: report the metric to Tune</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># Added: launch the hyperparameter search</span></span><br><span class="line">analysis = tune.run(</span><br><span class="line">    train_mnist,</span><br><span class="line">    num_samples=<span 
class="number">10</span>,</span><br><span class="line">    <span class="comment"># Uncomment this to let each evaluation use 1 GPU</span></span><br><span class="line">    <span class="comment"># resources_per_trial=&#123;&quot;CPU&quot;: 1, &quot;GPU&quot;: 1&#125;,</span></span><br><span class="line">    config=&#123;<span class="string">&quot;lr&quot;</span>: tune.grid_search([<span class="number">0.001</span>, <span class="number">0.01</span>, <span class="number">0.1</span>])&#125;)</span><br><span class="line"></span><br><span class="line"><span class="built_in">print</span>(<span class="string">&quot;Best config: &quot;</span>, analysis.get_best_config(metric=<span class="string">&quot;mean_accuracy&quot;</span>, mode=<span class="string">&quot;max&quot;</span>))</span><br><span class="line"></span><br><span class="line"><span class="comment"># Get the results as a DataFrame</span></span><br><span class="line">df = analysis.dataframe()</span><br></pre></td></tr></table></figure><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">https://github.com/IntelLabs/coach</span><br><span class="line">https://github.com/cjy1992/gym-carla</span><br><span class="line">https://github.com/LovelyBuggies/sumo-gym</span><br><span class="line">https://github.com/SaloniDash7/gym-sumo</span><br><span class="line">https://github.com/LucasAlegre/sumo-rl</span><br><span class="line"></span><br><span class="line">pip install git+https://github.com/DLR-RM/stable-baselines3@feat/gymnasium-support</span><br><span class="line">pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib@feat/gymnasium-support</span><br></pre></td></tr></table></figure><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span 
class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span 
class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br></pre></td><td class="code"><pre><span class="line"></span><br><span class="line"><span class="keyword">import</span> gym</span><br><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">from</span> gym <span class="keyword">import</span> spaces</span><br><span class="line"><span class="keyword">from</span> imitation.algorithms.adversarial.gail <span class="keyword">import</span> GAIL</span><br><span class="line"><span class="keyword">from</span> imitation.data <span class="keyword">import</span> rollout</span><br><span class="line"><span class="keyword">from</span> imitation.data.types <span class="keyword">import</span> Transitions</span><br><span class="line"><span class="keyword">from</span> imitation.data.wrappers <span class="keyword">import</span> RolloutInfoWrapper</span><br><span class="line"><span class="keyword">from</span> imitation.rewards.reward_nets <span class="keyword">import</span> BasicRewardNet</span><br><span class="line"><span class="keyword">from</span> imitation.util <span class="keyword">import</span> logger <span class="keyword">as</span> imit_logger</span><br><span class="line"><span 
class="keyword">from</span> imitation.util.networks <span class="keyword">import</span> RunningNorm</span><br><span class="line"><span class="keyword">from</span> stable_baselines3 <span class="keyword">import</span> PPO  <span class="comment"># DQN coming soon</span></span><br><span class="line"><span class="keyword">from</span> stable_baselines3.common.env_checker <span class="keyword">import</span> check_env</span><br><span class="line"><span class="keyword">from</span> stable_baselines3.common.env_util <span class="keyword">import</span> make_vec_env</span><br><span class="line"><span class="keyword">from</span> stable_baselines3.common.evaluation <span class="keyword">import</span> evaluate_policy</span><br><span class="line"><span class="keyword">from</span> stable_baselines3.common.vec_env <span class="keyword">import</span> DummyVecEnv</span><br><span class="line"><span class="keyword">from</span> stable_baselines3.ppo <span class="keyword">import</span> MlpPolicy</span><br><span class="line"><span class="keyword">import</span> torch <span class="keyword">as</span> th</span><br><span class="line"></span><br><span class="line">log_dir = <span class="string">&quot;./tensorboard/Custom-Env&quot;</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">CustomEnv</span>(gym.Env):</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, max_steps=<span class="number">8</span></span>):</span><br><span class="line">        <span class="built_in">super</span>().__init__()</span><br><span class="line">        <span class="variable language_">self</span>.observation_space = spaces.Box(low=-<span class="number">1</span>, high=<span class="number">1</span>, shape=(<span class="number">2</span>,), dtype=np.float32)</span><br><span class="line">        <span 
class="variable language_">self</span>.action_space = spaces.Box(low=-<span class="number">1</span>, high=<span class="number">1</span>, shape=(<span class="number">2</span>,), dtype=np.float32)</span><br><span class="line">        <span class="variable language_">self</span>.max_steps = max_steps</span><br><span class="line">        <span class="variable language_">self</span>.n_steps = <span class="number">0</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">reset</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="variable language_">self</span>.n_steps = <span class="number">0</span></span><br><span class="line">        <span class="keyword">return</span> <span class="variable language_">self</span>.observation_space.sample()</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">step</span>(<span class="params">self, action</span>):</span><br><span class="line">        <span class="variable language_">self</span>.n_steps += <span class="number">1</span></span><br><span class="line"></span><br><span class="line">        done = <span class="literal">False</span></span><br><span class="line">        reward = <span class="number">0.0</span></span><br><span class="line">        <span class="keyword">if</span> <span class="variable language_">self</span>.n_steps &gt;= <span class="variable language_">self</span>.max_steps:</span><br><span class="line">            reward = <span class="number">1.0</span></span><br><span class="line">            done = <span class="literal">True</span></span><br><span class="line"></span><br><span class="line">        <span class="keyword">return</span> <span class="variable language_">self</span>.observation_space.sample(), reward, done, &#123;&#125;</span><br><span class="line"></span><br><span class="line"><span class="comment"># Load expert data from a file</span></span><br><span class="line"><span class="comment"># For now, the data is generated randomly</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">load_expert_transitions</span>(<span class="params">env, length</span>):</span><br><span class="line">    obs = np.array([env.observation_space.sample() <span class="keyword">for</span> _ <span class="keyword">in</span> <span class="built_in">range</span>(length)])</span><br><span class="line">    acts = np.array([env.action_space.sample() <span class="keyword">for</span> _ <span class="keyword">in</span> <span class="built_in">range</span>(length)])</span><br><span class="line">    infos = np.array([&#123;i: i&#125; <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(length)])</span><br><span class="line">    next_obs = np.array([env.observation_space.sample() <span class="keyword">for</span> _ <span class="keyword">in</span> <span class="built_in">range</span>(length)])</span><br><span class="line">    dones = np.zeros(length, dtype=<span class="built_in">bool</span>)</span><br><span class="line">    <span class="keyword">return</span> Transitions(obs=obs, acts=acts, infos=infos, next_obs=next_obs, dones=dones)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">if</span> __name__ == <span class="string">&quot;__main__&quot;</span>:</span><br><span class="line">    env = CustomEnv()</span><br><span class="line">    check_env(env)  <span class="comment"># warns or raises if the env does not follow the Gym API</span></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;The Custom environment check done&quot;</span>)</span><br><span class="line">    device = th.device(<span class="string">&quot;cuda&quot;</span> <span class="keyword">if</span> th.cuda.is_available() <span class="keyword">else</span> <span class="string">&quot;cpu&quot;</span>)</span><br><span class="line">    <span 
class="built_in">print</span>(device)</span><br><span class="line">    <span class="comment"># Alternative: sample demonstrations from a trained expert (sample_expert_transitions is not defined in this script)</span></span><br><span class="line">    <span class="comment"># transitions = sample_expert_transitions()</span></span><br><span class="line">    <span class="comment"># Load the expert data (currently random, see load_expert_transitions)</span></span><br><span class="line">    transitions = load_expert_transitions(env, <span class="number">2048</span>)</span><br><span class="line">    <span class="comment"># Build the networks for GAIL training</span></span><br><span class="line">    venv = make_vec_env(<span class="keyword">lambda</span>: env)</span><br><span class="line">    learner = PPO(</span><br><span class="line">        env=venv,</span><br><span class="line">        policy=MlpPolicy,</span><br><span class="line">        batch_size=<span class="number">64</span>,</span><br><span class="line">        ent_coef=<span class="number">0.0</span>,</span><br><span class="line">        learning_rate=<span class="number">0.0003</span>,</span><br><span class="line">        n_epochs=<span class="number">10</span>,</span><br><span class="line">        device=device,</span><br><span class="line">    )</span><br><span class="line">    reward_net = BasicRewardNet(</span><br><span class="line">        venv.observation_space, venv.action_space, normalize_input_layer=RunningNorm</span><br><span class="line">    )</span><br><span class="line">    custom_logger = imit_logger.configure(</span><br><span class="line">        folder=log_dir,</span><br><span class="line">        format_strs=[<span class="string">&quot;tensorboard&quot;</span>, <span class="string">&quot;stdout&quot;</span>],</span><br><span class="line">    )</span><br><span class="line"></span><br><span class="line">    gail_trainer = GAIL(</span><br><span class="line">        demonstrations=transitions,</span><br><span class="line">        demo_batch_size=<span class="number">2</span>,</span><br><span class="line">        gen_replay_buffer_capacity=<span class="number">2048</span>,</span><br><span class="line">        
n_disc_updates_per_round=<span class="number">4</span>,</span><br><span class="line">        venv=venv,</span><br><span class="line">        gen_algo=learner,</span><br><span class="line">        reward_net=reward_net,</span><br><span class="line">        log_dir=log_dir,</span><br><span class="line">        init_tensorboard=<span class="literal">False</span>,</span><br><span class="line">        init_tensorboard_graph=<span class="literal">False</span>,</span><br><span class="line">        custom_logger=custom_logger</span><br><span class="line">    )</span><br><span class="line"></span><br><span class="line">    learner_rewards_before_training, _ = evaluate_policy(</span><br><span class="line">        learner, venv, <span class="number">10</span>, return_episode_rewards=<span class="literal">True</span></span><br><span class="line">    )</span><br><span class="line">    gail_trainer.train(<span class="number">20000</span>)</span><br><span class="line">    learner_rewards_after_training, _ = evaluate_policy(</span><br><span class="line">        learner, venv, <span class="number">10</span>, return_episode_rewards=<span class="literal">True</span></span><br><span class="line">    )</span><br><span class="line">    <span class="comment"># Compare the rewards before and after training</span></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;mean reward after training:&quot;</span>, np.mean(learner_rewards_after_training))</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;mean reward before training:&quot;</span>, np.mean(learner_rewards_before_training))</span><br><span class="line"></span><br><span class="line">    plt.hist(</span><br><span class="line">        [learner_rewards_before_training, learner_rewards_after_training],</span><br><span class="line">        label=[<span class="string">&quot;untrained&quot;</span>, <span class="string">&quot;trained&quot;</span>],</span><br><span class="line">    )</span><br><span class="line">    plt.legend()</span><br><span class="line">    plt.show()</span><br><span class="line">    <span 
class="comment"># View the logs with: tensorboard --logdir ./tensorboard/Custom-Env</span></span><br><span class="line">    <span class="comment"># Save the trained PPO generator (Stable-Baselines3 format, not ONNX) and load it back</span></span><br><span class="line">    learner.save(<span class="string">&quot;./gail.model&quot;</span>)</span><br><span class="line">    model = PPO.load(<span class="string">&quot;./gail.model&quot;</span>)</span><br><span class="line">    <span class="built_in">print</span>(model.predict(env.reset(), deterministic=<span class="literal">True</span>))</span><br><span class="line"></span><br></pre></td></tr></table></figure><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># Based on the implementations in flow and sumo-rl</span></span><br><span class="line"><span class="comment"># The Environment contains TrafficLight</span></span><br><span class="line"></span><br><span 
class="line"><span class="variable language_">self</span>.simulation = TraCISimulation(<span class="variable language_">self</span>) <span class="comment"># TraCI/SUMO parameters such as sim_step, simulation time and GUI settings</span></span><br><span class="line"><span class="comment"># Static, except that the simulation time changes; used to store simulation parameters</span></span><br><span class="line"></span><br><span class="line"><span class="variable language_">self</span>.network = TraCIKernelNetwork(<span class="variable language_">self</span>, sim_params) <span class="comment"># edges, nodes, edge_max_speed, edge_length, SUMO configuration files; static (networks do not change, so updating it performs no action of value)</span></span><br><span class="line"><span class="comment"># Split into: edges_dict, conn_dict</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="variable language_">self</span>.vehicle = TraCIVehicle(<span class="variable language_">self</span>, sim_params) <span class="comment"># Vehicle data: __controlled_ids, _num_arrived (arrived vehicles), previous_speeds, and self.__sumo_obs (positions of the controlled vehicles)</span></span><br><span class="line"></span><br><span class="line">tc.VAR_LANE_INDEX, tc.VAR_LANEPOSITION,</span><br><span class="line">            tc.VAR_ROAD_ID,</span><br><span class="line">            tc.VAR_SPEED,</span><br><span class="line">            tc.VAR_EDGES,</span><br><span class="line">            tc.VAR_POSITION,</span><br><span class="line">            tc.VAR_ANGLE,</span><br><span class="line">            tc.VAR_SPEED_WITHOUT_TRACI,</span><br><span class="line">            tc.VAR_FUELCONSUMPTION,</span><br><span class="line">            tc.VAR_DISTANCE</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="variable language_">self</span>.traffic_light = TraCITrafficLight(<span class="variable language_">self</span>) <span class="comment"># Traffic light data</span></span><br><span class="line"></span><br><span 
class="line"><span class="comment"># What actually drives the reinforcement-learning training:</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># step() runs the simulation and then the update</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># https://github.com/zbzhu99/NGSIM_Imitation</span></span><br><span class="line"><span class="comment"># https://github.com/wsjeon/multiagent-gail/tree/e7dd75f0dee17e33e55d7f4e24d40649fd648cf3</span></span><br><span class="line">    </span><br><span class="line">    </span><br></pre></td></tr></table></figure>]]></content>
    
    <summary type="html">
    
      &lt;h3 id=&quot;Ray&quot;&gt;&lt;a href=&quot;#Ray&quot; class=&quot;headerlink&quot; title=&quot;Ray&quot;&gt;&lt;/a&gt;Ray&lt;/h3&gt;&lt;p&gt;Ray是一个用于并行和分布式 Python 的开源项目，当我们将应用程序迁移到分布式设置时，传统编程概念会发生变化。比如用于模型培训的 TensorFlow、用于数据处理和 SQL 的 Spark 以及用于流处理的 Flink。这些工具提供&lt;strong&gt;更高层次的抽象，如神经网络、数据集和流&lt;/strong&gt;。但是，由于它们与串行编程所使用的抽象不同，因此&lt;strong&gt;必须重新编写应用程序&lt;/strong&gt;以利用它们。&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://img-blog.csdnimg.cn/682bfc588fe9442a8dab831a23e00800.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
&lt;p&gt;Ray占据了一个独特的&lt;strong&gt;中间地带&lt;/strong&gt;。而不是引入新的概念。Ray 获取函数和类的现有概念，并将它们作为任务和参与者转换为分布式设置。这种 &lt;strong&gt;API 选择允许串行应用程序并行化，而不需要进行重大修改&lt;/strong&gt;。&lt;/p&gt;
&lt;p&gt;Ray 可以用来在多个核心或机器上扩展 Python 应用。它有几个主要的优点，包括：&lt;/p&gt;
&lt;ul&gt;
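&lt;li&gt;&lt;p&gt;Simplicity: you can scale your Python application without rewriting it, and the same code runs on one machine or on many.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;&lt;p&gt;Robustness: applications handle machine failures and process preemption gracefully.&lt;/p&gt;
&lt;/li&gt;
&lt;li&gt;&lt;p&gt;Performance: tasks run with millisecond-level latency, scale to tens of thousands of cores, and handle numerical data with minimal serialization overhead.&lt;/p&gt;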
&lt;/li&gt;
&lt;/ul&gt;
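&lt;p&gt;&lt;img src=&quot;https://img-blog.csdnimg.cn/19c087b261e64effa0eab26efafaf4a1.png&quot; alt=&quot;figure&quot;&gt;&lt;/p&gt;
&lt;p&gt;As a distributed computing system, Ray in its early architecture follows the typical Master-Slave design: the Master handles global coordination and state maintenance, while the Slaves execute distributed computation tasks. Unlike traditional distributed computing systems, however, Ray uses &lt;strong&gt;hybrid task scheduling&lt;/strong&gt;.&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;GlobalScheduler&lt;/strong&gt;: the Master runs a global scheduler that receives tasks submitted by the local schedulers and dispatches them to suitable local schedulers for execution.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;RedisServer&lt;/strong&gt;: the Master runs one or more Redis servers that store the state of distributed tasks (ControlState), including the mapping from objects to machines, task descriptions, and task debugging information.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;LocalScheduler&lt;/strong&gt;: each Slave runs a local scheduler that submits tasks to the global scheduler and assigns tasks to the Worker processes on its own machine.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Worker&lt;/strong&gt;: each Slave can run multiple Worker processes that execute distributed tasks and store their results in the ObjectStore.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;ObjectStore&lt;/strong&gt;: each Slave runs an ObjectStore that holds read-only data objects. Workers access these objects through shared memory, which cuts down on memory copies and object serialization costs. The ObjectStore is built on Apache Arrow.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Plasma&lt;/strong&gt;: the ObjectStore on each Slave is managed by an object manager called Plasma. When a Worker accesses a remote data object that is not in the local ObjectStore, Plasma proactively pulls that object from another Slave to the current machine.&lt;/li&gt;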
&lt;/ul&gt;
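&lt;p&gt;The following toy, single-process Python sketch makes the hybrid scheduling flow above concrete. It is an illustration under stated assumptions, not Ray's implementation. The names Node, LocalScheduler, GlobalScheduler and get_object are hypothetical and not part of Ray's API. The placement policy is an assumed simplification of Ray's bottom-up scheduling: a task stays on its own node while a worker is free, and is otherwise forwarded to the global scheduler, which picks the node with the most free workers. A plain dict stands in for each node's ObjectStore.&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
```python
# Toy model of two-level (hybrid) scheduling; hypothetical names, NOT Ray's API.


class Node:
    """One machine (a Slave): some worker slots plus a local object store."""

    def __init__(self, name, num_workers):
        self.name = name
        self.num_workers = num_workers
        self.running = []        # tasks assigned to this node's workers
        self.object_store = {}   # object id -> value (stand-in for the ObjectStore)

    def free_workers(self):
        return self.num_workers - len(self.running)


class GlobalScheduler:
    """Runs on the Master and places tasks forwarded by local schedulers."""

    def __init__(self, nodes):
        self.nodes = nodes

    def schedule(self, task):
        # Assumed policy: choose the node with the most free workers.
        target = max(self.nodes, key=lambda node: node.free_workers())
        target.running.append(task)
        return target.name


class LocalScheduler:
    """Runs on each Slave and keeps work local whenever it can."""

    def __init__(self, node, global_scheduler):
        self.node = node
        self.global_scheduler = global_scheduler

    def submit(self, task):
        if self.node.free_workers() > 0:
            self.node.running.append(task)            # run on a local worker
            return self.node.name
        return self.global_scheduler.schedule(task)   # spill over to the Master


def get_object(local, nodes, object_id):
    """Read an object locally; on a miss, copy it from a node that holds it."""
    if object_id not in local.object_store:
        for other in nodes:
            if object_id in other.object_store:
                local.object_store[object_id] = other.object_store[object_id]
                break
    return local.object_store[object_id]  # raises KeyError if no node holds it


a, b = Node("a", 1), Node("b", 2)
ls_a = LocalScheduler(a, GlobalScheduler([a, b]))
placements = [ls_a.submit(t) for t in ["t1", "t2", "t3"]]
print(placements)  # ['a', 'b', 'b']

b.object_store["x"] = [1, 2, 3]     # result produced by a task on node b
print(get_object(a, [a, b], "x"))   # [1, 2, 3], now also cached on node a
```
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;Node a has only one worker. The first task submitted there stays local, and the next two are forwarded to node b. The last call copies object x from node b into the store of node a, mimicking the pull-on-miss behaviour attributed to Plasma above.&lt;/p&gt;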
&lt;p&gt;As in Spark, Ray tasks are submitted &lt;strong&gt;through a Driver&lt;/strong&gt;, with the following differences:&lt;/p&gt;
    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/2847873748.html"/>
    <id>https://blog.aivgg.com/posts/2847873748.html</id>
    <published>2026-06-10T16:18:36.417Z</published>
    <updated>2026-06-10T16:24:04.380Z</updated>
    
    <content type="html"><![CDATA[<h3 id="LeNet-5"><a href="#LeNet-5" class="headerlink" title="LeNet-5"></a>LeNet-5</h3><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span 
class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span 
class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> init</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> sys</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">import</span> torchvision.transforms <span class="keyword">as</span> transforms</span><br><span class="line"><span class="keyword">import</span> time</span><br><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line"><span class="keyword">import</span> os</span><br><span class="line"> </span><br><span class="line">os.environ[<span class="string">&quot;KMP_DUPLICATE_LIB_OK&quot;</span>]=<span class="string">&quot;TRUE&quot;</span></span><br><span class="line"> </span><br><span class="line"> </span><br><span class="line"><span class="comment"># Load the FashionMNIST dataset (downloaded on the first run)</span></span><br><span class="line">mnist_train = torchvision.datasets.FashionMNIST(root=<span class="string">&#x27;~/Datasets/FashionMNIST&#x27;</span>, train=<span class="literal">True</span>, download=<span class="literal">True</span>, transform=transforms.ToTensor())</span><br><span class="line">mnist_test = torchvision.datasets.FashionMNIST(root=<span class="string">&#x27;~/Datasets/FashionMNIST&#x27;</span>, train=<span class="literal">False</span>, download=<span 
class="literal">True</span>, transform=transforms.ToTensor())</span><br><span class="line"> </span><br><span class="line"> </span><br><span class="line"><span class="comment"># Wrap the datasets in DataLoaders that yield batches of tensors for the network defined below</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">load_data_fashion_mnist</span>(<span class="params">mnist_train, mnist_test, batch_size</span>):</span><br><span class="line">    <span class="keyword">if</span> sys.platform.startswith(<span class="string">&#x27;win&#x27;</span>):</span><br><span class="line">        num_workers = <span class="number">0</span></span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">        num_workers = <span class="number">4</span></span><br><span class="line">    train_data = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=<span class="literal">True</span>, num_workers=num_workers)</span><br><span class="line">    test_data = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=<span class="literal">False</span>, num_workers=num_workers)</span><br><span class="line">    <span class="keyword">return</span> train_data, test_data</span><br><span class="line"> </span><br><span class="line"> </span><br><span class="line"><span class="keyword">class</span> <span class="title class_">LeNet</span>(nn.Module):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="built_in">super</span>(LeNet, <span class="variable language_">self</span>).__init__()</span><br><span class="line">        <span class="variable language_">self</span>.conv = nn.Sequential(</span><br><span class="line">            nn.Conv2d(in_channels=<span class="number">1</span>, out_channels=<span class="number">6</span>, kernel_size=<span class="number">5</span>), <span class="comment"># 
in_channels, out_channels, kernel_size</span></span><br><span class="line">            nn.LeakyReLU(<span class="number">0.1</span>),</span><br><span class="line">            nn.MaxPool2d(<span class="number">2</span>, <span class="number">2</span>), <span class="comment"># kernel_size, stride</span></span><br><span class="line">            nn.Conv2d(<span class="number">6</span>, <span class="number">16</span>, <span class="number">5</span>),</span><br><span class="line">            nn.LeakyReLU(<span class="number">0.1</span>),</span><br><span class="line">            nn.MaxPool2d(<span class="number">2</span>, <span class="number">2</span>)</span><br><span class="line">        )</span><br><span class="line">        <span class="variable language_">self</span>.fc = nn.Sequential(</span><br><span class="line">            nn.Linear(<span class="number">16</span>*<span class="number">4</span>*<span class="number">4</span>, <span class="number">120</span>), <span class="comment"># 28x28 input: conv 24, pool 12, conv 8, pool 4, so 16 x 4 x 4 features</span></span><br><span class="line">            nn.LeakyReLU(<span class="number">0.1</span>),</span><br><span class="line">            nn.Linear(<span class="number">120</span>, <span class="number">84</span>),</span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            nn.Linear(<span class="number">84</span>, <span class="number">10</span>)</span><br><span class="line">        )</span><br><span class="line"> </span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, img</span>):</span><br><span class="line">        feature = <span class="variable language_">self</span>.conv(img)</span><br><span class="line">        output = <span class="variable language_">self</span>.fc(feature.view(img.shape[<span class="number">0</span>], -<span class="number">1</span>))</span><br><span class="line">        <span class="keyword">return</span> output</span><br><span class="line"> </span><br><span 
class="comment"># Compute classification accuracy over a dataset</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">evaluate_accuracy</span>(<span class="params">data_iter, net, device=<span class="literal">None</span></span>):</span><br><span class="line">    <span class="keyword">if</span> device <span class="keyword">is</span> <span class="literal">None</span> <span class="keyword">and</span> <span class="built_in">isinstance</span>(net, torch.nn.Module):</span><br><span class="line">        <span class="comment"># If no device is given, use the device that net's parameters live on</span></span><br><span class="line">        device = <span class="built_in">list</span>(net.parameters())[<span class="number">0</span>].device</span><br><span class="line">    acc_sum, n = <span class="number">0.0</span>, <span class="number">0</span></span><br><span class="line">    <span class="keyword">with</span> torch.no_grad():</span><br><span class="line">        <span class="keyword">for</span> X, y <span class="keyword">in</span> data_iter:</span><br><span class="line">            net.<span class="built_in">eval</span>()  <span class="comment"># evaluation mode (disables dropout)</span></span><br><span class="line">            acc_sum += (net(X.to(device)).argmax(dim=<span class="number">1</span>) == y.to(device)).<span class="built_in">float</span>().<span class="built_in">sum</span>().cpu().item()</span><br><span class="line">            net.train()  <span class="comment"># switch back to training mode</span></span><br><span class="line">            n += y.shape[<span class="number">0</span>]</span><br><span class="line">    <span class="keyword">return</span> acc_sum / n</span><br><span class="line"> </span><br><span class="line"> </span><br><span class="line"><span class="comment"># Training loop</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">train</span>(<span class="params">net, train_data, test_data, batch_size, optimizer, device, num_epochs</span>):</span><br><span class="line">    net = 
net.to(device)</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;training on &quot;</span>, device)</span><br><span class="line">    loss_function = torch.nn.CrossEntropyLoss()   <span class="comment"># loss function (cross-entropy)</span></span><br><span class="line">    ax = []  <span class="comment"># history of epoch, loss, train_acc and test_acc for the live line chart</span></span><br><span class="line">    ay1 = []</span><br><span class="line">    ay2 = []</span><br><span class="line">    ay3 = []</span><br><span class="line">    plt.ion()  <span class="comment"># interactive mode, so the chart updates during training</span></span><br><span class="line">    <span class="comment"># Start training</span></span><br><span class="line">    <span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(num_epochs):</span><br><span class="line">        train_l_sum, train_acc_sum, n, batch_count, start = <span class="number">0.0</span>, <span class="number">0.0</span>, <span class="number">0</span>, <span class="number">0</span>, time.time()  <span class="comment"># reset the per-epoch accumulators</span></span><br><span class="line">        <span class="keyword">for</span> X, y <span class="keyword">in</span> train_data:</span><br><span class="line">            X = X.to(device)      <span class="comment"># move the batch to the training device (GPU if available)</span></span><br><span class="line">            y = y.to(device)</span><br><span class="line">            y_hat = net(X)</span><br><span class="line">            l = loss_function(y_hat, y)   <span class="comment"># compute the loss</span></span><br><span class="line">            optimizer.zero_grad() <span class="comment"># zero the gradients, i.e. reset d(loss)/d(weight) to 0</span></span><br><span class="line">            l.backward()   <span class="comment"># backpropagation</span></span><br><span class="line">            optimizer.step()</span><br><span class="line">            train_l_sum += l.cpu().item()</span><br><span class="line">            train_acc_sum += (y_hat.argmax(dim=<span class="number">1</span>) == y).<span 
class="built_in">sum</span>().cpu().item()</span><br><span class="line">            n += y.shape[<span class="number">0</span>]</span><br><span class="line">            batch_count += <span class="number">1</span></span><br><span class="line">        test_acc = evaluate_accuracy(test_data, net)  <span class="comment"># evaluate this epoch's network on the test set</span></span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&#x27;epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec&#x27;</span></span><br><span class="line">              % (epoch + <span class="number">1</span>, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))</span><br><span class="line">        <span class="comment"># Update the live line chart (delete this block if you do not need it)</span></span><br><span class="line">        plt.clf()  <span class="comment"># clear the previous frame so old plots do not pile up in memory</span></span><br><span class="line">        ax.append(epoch + <span class="number">1</span>)  <span class="comment"># append the x value</span></span><br><span class="line">        ay1.append(train_l_sum / batch_count)  <span class="comment"># append the y values</span></span><br><span class="line">        ay2.append(train_acc_sum / n)</span><br><span class="line">        ay3.append(test_acc)</span><br><span class="line">        <span class="comment"># plot each curve once, with a colour and a legend label</span></span><br><span class="line">        plt.plot(ax, ay1, <span class="string">&#x27;g-&#x27;</span>, label=<span class="string">&quot;loss&quot;</span>)</span><br><span class="line">        plt.plot(ax, ay2, <span class="string">&#x27;r-&#x27;</span>, label=<span class="string">&quot;train_acc&quot;</span>)</span><br><span class="line">        plt.plot(ax, ay3, <span class="string">&#x27;b-&#x27;</span>, label=<span class="string">&quot;test_acc&quot;</span>)</span><br><span class="line">        plt.xlabel(<span class="string">&quot;epoch&quot;</span>)</span><br><span class="line">        plt.ylabel(<span class="string">&quot;loss / accuracy&quot;</span>)</span><br><span class="line">        plt.title(<span 
class="string">&quot;LeNet-5 on FashionMNIST&quot;</span>)</span><br><span class="line">        plt.legend(loc=<span class="number">2</span>)  <span class="comment"># add a legend; loc: 1 = upper right, 2 = upper left, 3 = lower left, 4 = lower right</span></span><br><span class="line">        plt.grid()   <span class="comment"># add a grid</span></span><br><span class="line">        plt.pause(<span class="number">5</span>)  <span class="comment"># pause so the chart has time to render</span></span><br><span class="line">    plt.ioff()  <span class="comment"># after the last epoch, turn interactive mode off</span></span><br><span class="line">    plt.show()  <span class="comment"># block so the final chart stays open instead of closing immediately</span></span><br><span class="line"> </span><br><span class="line"> </span><br><span class="line"><span class="keyword">if</span> __name__ == <span class="string">&#x27;__main__&#x27;</span>:</span><br><span class="line">    batch_size = <span class="number">256</span>   <span class="comment"># batch size</span></span><br><span class="line">    train_data, test_data = load_data_fashion_mnist(mnist_train, mnist_test, batch_size)</span><br><span class="line">    device = torch.device(<span class="string">&#x27;cuda&#x27;</span> <span class="keyword">if</span> torch.cuda.is_available() <span class="keyword">else</span> <span class="string">&#x27;cpu&#x27;</span>)  <span class="comment"># use the GPU if available, otherwise the CPU</span></span><br><span class="line">    net = LeNet()    <span class="comment"># instantiate the network defined above</span></span><br><span class="line">    lr, num_epochs = <span class="number">0.001</span>, <span class="number">10</span></span><br><span class="line">    optimizer = torch.optim.Adam(net.parameters(), lr=lr)  <span class="comment"># optimizer</span></span><br><span class="line">    train(net, train_data, test_data, batch_size, optimizer, device, num_epochs)</span><br></pre></td></tr></table></figure><h3 id="AlexNet"><a href="#AlexNet" class="headerlink" title="AlexNet"></a>AlexNet</h3><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span 
class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">AlexNet</span>(nn.Module):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="built_in">super</span>(AlexNet, <span class="variable language_">self</span>).__init__()</span><br><span class="line">        <span class="variable language_">self</span>.conv = nn.Sequential(</span><br><span class="line">            nn.Conv2d(<span class="number">1</span>, <span class="number">96</span>, <span class="number">11</span>, <span class="number">4</span>), <span class="comment"># in_channels, out_channels, kernel_size, stride, padding</span></span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            nn.MaxPool2d(<span class="number">3</span>, <span 
class="number">2</span>), <span class="comment"># kernel_size, stride</span></span><br><span class="line">            <span class="comment"># Smaller kernel; padding 2 keeps the input and output height/width equal, and the number of output channels grows</span></span><br><span class="line">            nn.Conv2d(<span class="number">96</span>, <span class="number">256</span>, <span class="number">5</span>, <span class="number">1</span>, <span class="number">2</span>),</span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            nn.MaxPool2d(<span class="number">3</span>, <span class="number">2</span>),</span><br><span class="line">            <span class="comment"># Three consecutive conv layers with an even smaller kernel; all but the last further increase the number of output channels</span></span><br><span class="line">            nn.Conv2d(<span class="number">256</span>, <span class="number">384</span>, <span class="number">3</span>, <span class="number">1</span>, <span class="number">1</span>),</span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            nn.Conv2d(<span class="number">384</span>, <span class="number">384</span>, <span class="number">3</span>, <span class="number">1</span>, <span class="number">1</span>),</span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            nn.Conv2d(<span class="number">384</span>, <span class="number">256</span>, <span class="number">3</span>, <span class="number">1</span>, <span class="number">1</span>),</span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            nn.MaxPool2d(<span class="number">3</span>, <span class="number">2</span>)</span><br><span class="line">        )</span><br><span class="line">        <span class="variable language_">self</span>.fc = nn.Sequential(</span><br><span class="line">            nn.Linear(<span class="number">256</span>*<span class="number">5</span>*<span class="number">5</span>, <span class="number">4096</span>), <span class="comment"># 256 channels x 5 x 5 after the conv stack for a 224x224 input</span></span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            
nn.Dropout(<span class="number">0.5</span>),</span><br><span class="line">            nn.Linear(<span class="number">4096</span>, <span class="number">4096</span>),</span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            nn.Dropout(<span class="number">0.5</span>),</span><br><span class="line">            nn.Linear(<span class="number">4096</span>, <span class="number">10</span>),</span><br><span class="line">        )</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, img</span>):</span><br><span class="line">        feature = <span class="variable language_">self</span>.conv(img)</span><br><span class="line">        output = <span class="variable language_">self</span>.fc(feature.view(img.shape[<span class="number">0</span>], -<span class="number">1</span>))</span><br><span class="line">        <span class="keyword">return</span> output</span><br></pre></td></tr></table></figure><p>Full implementation (FashionMNIST images are resized to 224x224 to match AlexNet's input size):</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span 
class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span 
class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> time</span><br><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn, optim</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">import</span> sys</span><br><span class="line"></span><br><span class="line">device = torch.device(<span class="string">&#x27;cuda&#x27;</span> <span class="keyword">if</span> torch.cuda.is_available() <span class="keyword">else</span> <span class="string">&#x27;cpu&#x27;</span>)</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">load_data_fashion_mnist</span>(<span class="params">batch_size, resize=<span class="literal">None</span>, root=<span class="string">&#x27;~/Datasets/FashionMNIST&#x27;</span></span>):</span><br><span class="line">    <span class="keyword">if</span> sys.platform.startswith(<span class="string">&#x27;win&#x27;</span>):</span><br><span class="line">        num_workers = <span class="number">0</span></span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">        num_workers = <span class="number">4</span></span><br><span class="line">    trans = []</span><br><span 
class="line">    <span class="keyword">if</span> resize:</span><br><span class="line">        trans.append(torchvision.transforms.Resize(size=resize))</span><br><span class="line">    trans.append(torchvision.transforms.ToTensor())</span><br><span class="line"></span><br><span class="line">    transform = torchvision.transforms.Compose(trans)</span><br><span class="line">    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=<span class="literal">True</span>, download=<span class="literal">True</span>, transform=transform)</span><br><span class="line">    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=<span class="literal">False</span>, download=<span class="literal">True</span>, transform=transform)</span><br><span class="line"></span><br><span class="line">    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=<span class="literal">True</span>, num_workers=num_workers)</span><br><span class="line">    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=<span class="literal">False</span>, num_workers=num_workers)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">return</span> train_iter, test_iter</span><br><span class="line"></span><br><span class="line">batch_size = <span class="number">128</span></span><br><span class="line">train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=<span class="number">224</span>)</span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">AlexNet</span>(nn.Module):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="built_in">super</span>(AlexNet, <span class="variable language_">self</span>).__init__()</span><br><span class="line">        <span class="variable 
language_">self</span>.conv = nn.Sequential(</span><br><span class="line">            nn.Conv2d(<span class="number">1</span>, <span class="number">96</span>, <span class="number">11</span>, <span class="number">4</span>), <span class="comment"># in_channels, out_channels, kernel_size, stride (padding defaults to 0)</span></span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            nn.MaxPool2d(<span class="number">3</span>, <span class="number">2</span>), <span class="comment"># kernel_size, stride</span></span><br><span class="line">            <span class="comment"># Smaller 5x5 window; padding=2 keeps the output height and width equal to those of the input, and the output channels grow to 256</span></span><br><span class="line">            nn.Conv2d(<span class="number">96</span>, <span class="number">256</span>, <span class="number">5</span>, <span class="number">1</span>, <span class="number">2</span>),</span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            nn.MaxPool2d(<span class="number">3</span>, <span class="number">2</span>),</span><br><span class="line">            <span class="comment"># Three consecutive conv layers with an even smaller 3x3 window; every layer except the last outputs more channels (384)</span></span><br><span class="line">            nn.Conv2d(<span class="number">256</span>, <span class="number">384</span>, <span class="number">3</span>, <span class="number">1</span>, <span class="number">1</span>),</span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            nn.Conv2d(<span class="number">384</span>, <span class="number">384</span>, <span class="number">3</span>, <span class="number">1</span>, <span class="number">1</span>),</span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            nn.Conv2d(<span class="number">384</span>, <span class="number">256</span>, <span class="number">3</span>, <span class="number">1</span>, <span class="number">1</span>),</span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            nn.MaxPool2d(<span 
class="number">3</span>, <span class="number">2</span>)</span><br><span class="line">        )</span><br><span class="line">        <span class="variable language_">self</span>.fc = nn.Sequential(</span><br><span class="line">            nn.Linear(<span class="number">256</span>*<span class="number">5</span>*<span class="number">5</span>, <span class="number">4096</span>),</span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            nn.Dropout(<span class="number">0.5</span>),</span><br><span class="line">            nn.Linear(<span class="number">4096</span>, <span class="number">4096</span>),</span><br><span class="line">            nn.ReLU(),</span><br><span class="line">            nn.Dropout(<span class="number">0.5</span>),</span><br><span class="line">            nn.Linear(<span class="number">4096</span>, <span class="number">10</span>),</span><br><span class="line">        )</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, img</span>):</span><br><span class="line">        feature = <span class="variable language_">self</span>.conv(img)</span><br><span class="line">        output = <span class="variable language_">self</span>.fc(feature.view(img.shape[<span class="number">0</span>], -<span class="number">1</span>))</span><br><span class="line">        <span class="keyword">return</span> output</span><br><span class="line"></span><br><span class="line">net = AlexNet()</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">evaluate_accuracy</span>(<span class="params">data_iter, net, device=<span class="literal">None</span></span>):</span><br><span class="line">    <span class="keyword">if</span> device <span class="keyword">is</span> <span class="literal">None</span> <span class="keyword">and</span> <span class="built_in">isinstance</span>(net, 
torch.nn.Module):</span><br><span class="line">        <span class="comment"># If no device is given, use the device the net parameters are on</span></span><br><span class="line">        device = <span class="built_in">list</span>(net.parameters())[<span class="number">0</span>].device</span><br><span class="line">    acc_sum, n = <span class="number">0.0</span>, <span class="number">0</span></span><br><span class="line">    net.<span class="built_in">eval</span>() <span class="comment"># Evaluation mode: disables dropout for the whole pass</span></span><br><span class="line">    <span class="keyword">with</span> torch.no_grad():</span><br><span class="line">        <span class="keyword">for</span> X, y <span class="keyword">in</span> data_iter:</span><br><span class="line">            acc_sum += (net(X.to(device)).argmax(dim=<span class="number">1</span>) == y.to(device)).<span class="built_in">float</span>().<span class="built_in">sum</span>().cpu().item()</span><br><span class="line">            n += y.shape[<span class="number">0</span>]</span><br><span class="line">    net.train() <span class="comment"># Switch back to training mode</span></span><br><span class="line">    <span class="keyword">return</span> acc_sum / n</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">train</span>(<span class="params">net, train_iter, test_iter, batch_size, optimizer, device, num_epochs</span>):</span><br><span class="line">    net = net.to(device)</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">&quot;training on &quot;</span>, device)</span><br><span class="line">    loss = torch.nn.CrossEntropyLoss()</span><br><span class="line">    <span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(num_epochs):</span><br><span class="line">        train_l_sum, train_acc_sum, n, batch_count, start = <span class="number">0.0</span>, <span class="number">0.0</span>, <span 
class="number">0</span>, <span class="number">0</span>, time.time()</span><br><span class="line">        <span class="keyword">for</span> X, y <span class="keyword">in</span> train_iter:</span><br><span class="line">            X = X.to(device)</span><br><span class="line">            y = y.to(device)</span><br><span class="line">            y_hat = net(X)</span><br><span class="line">            l = loss(y_hat, y)</span><br><span class="line">            optimizer.zero_grad()</span><br><span class="line">            l.backward()</span><br><span class="line">            optimizer.step()</span><br><span class="line">            train_l_sum += l.cpu().item()</span><br><span class="line">            train_acc_sum += (y_hat.argmax(dim=<span class="number">1</span>) == y).<span class="built_in">sum</span>().cpu().item()</span><br><span class="line">            n += y.shape[<span class="number">0</span>]</span><br><span class="line">            batch_count += <span class="number">1</span></span><br><span class="line">        test_acc = evaluate_accuracy(test_iter, net)</span><br><span class="line">        <span class="built_in">print</span>(<span class="string">&#x27;epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec&#x27;</span></span><br><span class="line">              % (epoch + <span class="number">1</span>, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))</span><br><span class="line"></span><br><span class="line">lr, num_epochs = <span class="number">0.001</span>, <span class="number">5</span></span><br><span class="line">optimizer = torch.optim.Adam(net.parameters(), lr=lr)</span><br><span class="line">train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)</span><br><span class="line"></span><br></pre></td></tr></table></figure>]]></content>
    
    <summary type="html">
    
      &lt;h3 id=&quot;LeNet-5&quot;&gt;&lt;a href=&quot;#LeNet-5&quot; class=&quot;headerlink&quot; title=&quot;LeNet-5&quot;&gt;&lt;/a&gt;LeNet-5&lt;/h3&gt;&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;8&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;9&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;10&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;11&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;12&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;13&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;14&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;15&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;16&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;17&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;18&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;19&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;20&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;21&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;22&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;23&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;24&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;25&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;26&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;27&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;28&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;29&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;30&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;31&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;32&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;33&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;34&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;35&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;36&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;37&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;38&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;39&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;40&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;41&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;42&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;43&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;44&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;45&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;46&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;47&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;48&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;49&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;50&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;51&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;52&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;53&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;54&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;55&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;56&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;57&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;58&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;59&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;60&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;61&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;62&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;63&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;64&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;65&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;66&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;67&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;68&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;69&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;70&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;71&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;72&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;73&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;74&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;75&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;76&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;77&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;78&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;79&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;80&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;81&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;82&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;83&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;84&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;85&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;86&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;87&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;88&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;89&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;90&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;91&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;92&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;93&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;94&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;95&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;96&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;97&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;98&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;99&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;100&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;101&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;102&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;103&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;104&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;105&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;106&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;107&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;108&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;109&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;110&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;111&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;112&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;113&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;114&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;115&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;116&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;117&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;118&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;119&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;120&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;121&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;122&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;123&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;124&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;125&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;&lt;span 
class=&quot;keyword&quot;&gt;import&lt;/span&gt; torch&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; torch &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; nn&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; torch.nn &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; init&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; numpy &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; np&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; sys&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; torchvision&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; torchvision.transforms &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; transforms&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; time&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; matplotlib.pyplot &lt;span class=&quot;keyword&quot;&gt;as&lt;/span&gt; plt&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; os&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;os.environ[&lt;span class=&quot;string&quot;&gt;&amp;quot;KMP_DUPLICATE_LIB_OK&amp;quot;&lt;/span&gt;]=&lt;span class=&quot;string&quot;&gt;&amp;quot;TRUE&amp;quot;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;comment&quot;&gt;# 
Download and load the FashionMNIST dataset&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;mnist_train = torchvision.datasets.FashionMNIST(root=&lt;span class=&quot;string&quot;&gt;&amp;#x27;~/Datasets/FashionMNIST&amp;#x27;&lt;/span&gt;, train=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;, download=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;, transform=transforms.ToTensor())&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;mnist_test = torchvision.datasets.FashionMNIST(root=&lt;span class=&quot;string&quot;&gt;&amp;#x27;~/Datasets/FashionMNIST&amp;#x27;&lt;/span&gt;, train=&lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;, download=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;, transform=transforms.ToTensor())&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;comment&quot;&gt;# Wrap the datasets in DataLoaders so batches of tensors can be fed to the network defined below&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;load_data_fashion_mnist&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;mnist_train, mnist_test, batch_size&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; sys.platform.startswith(&lt;span class=&quot;string&quot;&gt;&amp;#x27;win&amp;#x27;&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        num_workers = &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;else&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        num_workers = &lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    train_data = 
torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;, num_workers=num_workers)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    test_data = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=&lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;, num_workers=num_workers)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; train_data, test_data&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;title class_&quot;&gt;LeNet&lt;/span&gt;(nn.Module):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;__init__&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;built_in&quot;&gt;super&lt;/span&gt;(LeNet, &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;).__init__()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.conv = nn.Sequential(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Conv2d(in_channels=&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, out_channels=&lt;span class=&quot;number&quot;&gt;6&lt;/span&gt;, kernel_size=&lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;), &lt;span class=&quot;comment&quot;&gt;# in_channels, out_channels, kernel_size&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.LeakyReLU(&lt;span class=&quot;number&quot;&gt;0.1&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;     
       nn.MaxPool2d(&lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;), &lt;span class=&quot;comment&quot;&gt;# kernel_size, stride&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Conv2d(&lt;span class=&quot;number&quot;&gt;6&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;16&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.LeakyReLU(&lt;span class=&quot;number&quot;&gt;0.1&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.MaxPool2d(&lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        )&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.fc = nn.Sequential(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Linear(&lt;span class=&quot;number&quot;&gt;16&lt;/span&gt;*&lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;*&lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;120&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.LeakyReLU(&lt;span class=&quot;number&quot;&gt;0.1&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Linear(&lt;span class=&quot;number&quot;&gt;120&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;84&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Linear(&lt;span class=&quot;number&quot;&gt;84&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;10&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        )&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; 
&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;forward&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self, img&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        feature = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.conv(img)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        output = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.fc(feature.view(img.shape[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;], -&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; output&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;comment&quot;&gt;# Compute classification accuracy on a dataset&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;evaluate_accuracy&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;data_iter, net, device=&lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt;&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; device &lt;span class=&quot;keyword&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt; &lt;span class=&quot;keyword&quot;&gt;and&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;isinstance&lt;/span&gt;(net, torch.nn.Module):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;comment&quot;&gt;# If no device is given, use the device the net parameters are on&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        device = &lt;span class=&quot;built_in&quot;&gt;list&lt;/span&gt;(net.parameters())[&lt;span 
class=&quot;number&quot;&gt;0&lt;/span&gt;].device&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    acc_sum, n = &lt;span class=&quot;number&quot;&gt;0.0&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    net.&lt;span class=&quot;built_in&quot;&gt;eval&lt;/span&gt;()  &lt;span class=&quot;comment&quot;&gt;# Evaluation mode: disables dropout for the whole pass&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;with&lt;/span&gt; torch.no_grad():&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; X, y &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; data_iter:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            acc_sum += (net(X.to(device)).argmax(dim=&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;) == y.to(device)).&lt;span class=&quot;built_in&quot;&gt;float&lt;/span&gt;().&lt;span class=&quot;built_in&quot;&gt;sum&lt;/span&gt;().cpu().item()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            n += y.shape[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    net.train()  &lt;span class=&quot;comment&quot;&gt;# Switch back to training mode&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; acc_sum / n&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;comment&quot;&gt;# Training function&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;train&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;net, train_data, test_data, batch_size, optimizer, device, 
num_epochs&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    net = net.to(device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;built_in&quot;&gt;print&lt;/span&gt;(&lt;span class=&quot;string&quot;&gt;&amp;quot;training on &amp;quot;&lt;/span&gt;, device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    loss_function = torch.nn.CrossEntropyLoss()   &lt;span class=&quot;comment&quot;&gt;# Loss function: cross-entropy&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    ax = []  &lt;span class=&quot;comment&quot;&gt;# Per-epoch history of epoch, loss, train_acc and test_acc for the live line chart&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    ay1 = []&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    ay2 = []&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    ay3 = []&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    plt.ion()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;comment&quot;&gt;# Start training&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; epoch &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;range&lt;/span&gt;(num_epochs):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        train_l_sum, train_acc_sum, n, batch_count, start = &lt;span class=&quot;number&quot;&gt;0.0&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;0.0&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;, time.time()  &lt;span class=&quot;comment&quot;&gt;# Reset the per-epoch accumulators and start the timer&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; X, y &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; train_data:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            X = X.to(device)      
&lt;span class=&quot;comment&quot;&gt;# Move the batch to the training device (GPU if available)&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            y = y.to(device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            y_hat = net(X)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            l = loss_function(y_hat, y)   &lt;span class=&quot;comment&quot;&gt;# Compute the loss&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            optimizer.zero_grad() &lt;span class=&quot;comment&quot;&gt;# Reset the gradients (d loss / d weight) to 0; otherwise backward() accumulates them across batches&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            l.backward()   &lt;span class=&quot;comment&quot;&gt;# Backpropagation: compute the gradients&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            optimizer.step()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            train_l_sum += l.cpu().item()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            train_acc_sum += (y_hat.argmax(dim=&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;) == y).&lt;span class=&quot;built_in&quot;&gt;sum&lt;/span&gt;().cpu().item()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            n += y.shape[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            batch_count += &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        test_acc = evaluate_accuracy(test_data, net)  &lt;span class=&quot;comment&quot;&gt;# Evaluate the network on the test set after this epoch&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;built_in&quot;&gt;print&lt;/span&gt;(&lt;span class=&quot;string&quot;&gt;&amp;#x27;epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec&amp;#x27;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;              % (epoch + &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, train_l_sum / 
batch_count, train_acc_sum / n, test_acc, time.time() - start))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;comment&quot;&gt;# 绘制动态折线图（如果不想绘制，可以删掉）&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        plt.clf()  &lt;span class=&quot;comment&quot;&gt;# 清除刷新前的图表，防止数据量过大消耗内存&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        ax.append(epoch + &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;)  &lt;span class=&quot;comment&quot;&gt;# 追加x坐标值&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        ay1.append(train_l_sum / batch_count)  &lt;span class=&quot;comment&quot;&gt;# 追加y坐标值&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        ay2.append(train_acc_sum / n)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        ay3.append(test_acc)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        plt.plot(ax, ay1, &lt;span class=&quot;string&quot;&gt;&amp;#x27;g-&amp;#x27;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        plt.plot(ax, ay2, &lt;span class=&quot;string&quot;&gt;&amp;#x27;r-&amp;#x27;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        plt.plot(ax, ay3, &lt;span class=&quot;string&quot;&gt;&amp;#x27;-&amp;#x27;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        plt.ylabel(&lt;span class=&quot;string&quot;&gt;&amp;quot;epoch&amp;quot;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        plt.plot(ax, ay1, label=&lt;span class=&quot;string&quot;&gt;&amp;quot;loss&amp;quot;&lt;/span&gt;)  &lt;span class=&quot;comment&quot;&gt;# 在绘图函数添加一个属性label&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        plt.plot(ax, ay2, label=&lt;span class=&quot;string&quot;&gt;&amp;quot;train_acc&amp;quot;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        
plt.plot(ax, ay3, label=&lt;span class=&quot;string&quot;&gt;&amp;quot;test_acc&amp;quot;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        plt.legend(loc=&lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;)  &lt;span class=&quot;comment&quot;&gt;# 添加图例，loc为图例位置，1为右上角，2为左上角，3为左下角，4为右下角&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        plt.grid()   &lt;span class=&quot;comment&quot;&gt;# 添加网格&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        plt.pause(&lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;)  &lt;span class=&quot;comment&quot;&gt;# 设置暂停时间，太快图表无法正常显示&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        plt.ioff()  &lt;span class=&quot;comment&quot;&gt;# 关闭画图的窗口，即关闭交互模式&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    plt.show()  &lt;span class=&quot;comment&quot;&gt;# 显示图片，防止闪退&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt; &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; __name__ == &lt;span class=&quot;string&quot;&gt;&amp;#x27;__main__&amp;#x27;&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    batch_size = &lt;span class=&quot;number&quot;&gt;256&lt;/span&gt;   &lt;span class=&quot;comment&quot;&gt;# 批量数大小&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    train_data, test_data = load_data_fashion_mnist(mnist_train, mnist_test, batch_size)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    device = torch.device(&lt;span class=&quot;string&quot;&gt;&amp;#x27;cuda&amp;#x27;&lt;/span&gt; &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; torch.cuda.is_available() &lt;span class=&quot;keyword&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;string&quot;&gt;&amp;#x27;cpu&amp;#x27;&lt;/span&gt;)  &lt;span 
class=&quot;comment&quot;&gt;# 使用GPU,如果没有则使用CPU&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    net = LeNet()    &lt;span class=&quot;comment&quot;&gt;# 导入我们搭建好的网络&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    lr, num_epochs = &lt;span class=&quot;number&quot;&gt;0.001&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;10&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    optimizer = torch.optim.Adam(net.parameters(), lr=lr)  &lt;span class=&quot;comment&quot;&gt;# 优化函数&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    train(net, train_data, test_data, batch_size, optimizer, device, num_epochs)&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;
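&lt;p&gt;The numbers printed after each epoch are running averages. &lt;code&gt;CrossEntropyLoss&lt;/code&gt; returns the mean loss of a batch by default, so &lt;code&gt;train&lt;/code&gt; divides the summed loss by the number of batches. The count of correct predictions, by contrast, is divided by the number of samples &lt;code&gt;n&lt;/code&gt;. A minimal pure-Python sketch of that bookkeeping follows. The helper &lt;code&gt;epoch_metrics&lt;/code&gt; and the per-batch values are hypothetical and only illustrate the arithmetic.&lt;/p&gt;
&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;
```python
def epoch_metrics(batches):
    """Reproduce the per-epoch averages printed by train().

    batches: list of (mean_batch_loss, num_correct, batch_size) tuples,
    one per mini-batch (a hypothetical input format for this sketch).
    """
    train_l_sum, train_acc_sum, n, batch_count = 0.0, 0.0, 0, 0
    for batch_loss, num_correct, batch_size in batches:
        train_l_sum += batch_loss     # already a mean over the batch
        train_acc_sum += num_correct  # correct predictions in this batch
        n += batch_size               # samples seen so far
        batch_count += 1
    return train_l_sum / batch_count, train_acc_sum / n


# Hypothetical epoch: two full batches of 256 and a last batch of 96 samples.
loss, acc = epoch_metrics([(0.60, 200, 256), (0.50, 210, 256), (0.40, 80, 96)])
print("loss %.4f, train acc %.3f" % (loss, acc))  # loss 0.5000, train acc 0.806
```
&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;
&lt;p&gt;Averaging per batch gives the smaller last batch the same weight as a full one. Weighting each batch loss by its batch size would give a per-sample mean instead. With 60,000 training images and a batch size of 256, the difference is small.&lt;/p&gt;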



&lt;h3 id=&quot;AlexNet&quot;&gt;&lt;a href=&quot;#AlexNet&quot; class=&quot;headerlink&quot; title=&quot;AlexNet&quot;&gt;&lt;/a&gt;AlexNet&lt;/h3&gt;&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;8&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;9&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;10&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;11&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;12&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;13&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;14&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;15&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;16&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;17&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;18&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;19&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;20&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;21&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;22&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;23&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;24&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;25&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;26&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;27&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;28&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;29&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;30&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;31&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;32&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;33&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;34&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;title class_&quot;&gt;AlexNet&lt;/span&gt;(nn.Module):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;__init__&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;built_in&quot;&gt;super&lt;/span&gt;(AlexNet, &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;).__init__()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.conv = nn.Sequential(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Conv2d(&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;96&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;11&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;), &lt;span class=&quot;comment&quot;&gt;# in_channels, out_channels, kernel_size, stride, padding&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.MaxPool2d(&lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;), &lt;span class=&quot;comment&quot;&gt;# kernel_size, stride&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;      
      &lt;span class=&quot;comment&quot;&gt;# smaller 5x5 window; padding 2 keeps the height and width unchanged while the number of output channels grows&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Conv2d(&lt;span class=&quot;number&quot;&gt;96&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;256&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.MaxPool2d(&lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;comment&quot;&gt;# three consecutive conv layers with an even smaller 3x3 window; the first two widen to 384 channels, the last returns to 256&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Conv2d(&lt;span class=&quot;number&quot;&gt;256&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;384&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Conv2d(&lt;span class=&quot;number&quot;&gt;384&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;384&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Conv2d(&lt;span class=&quot;number&quot;&gt;384&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;256&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;, &lt;span 
class=&quot;number&quot;&gt;1&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.MaxPool2d(&lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        )&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.fc = nn.Sequential(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Linear(&lt;span class=&quot;number&quot;&gt;256&lt;/span&gt;*&lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;*&lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;4096&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Dropout(&lt;span class=&quot;number&quot;&gt;0.5&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Linear(&lt;span class=&quot;number&quot;&gt;4096&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;4096&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Dropout(&lt;span class=&quot;number&quot;&gt;0.5&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Linear(&lt;span class=&quot;number&quot;&gt;4096&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;10&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        )&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title 
function_&quot;&gt;forward&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self, img&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        feature = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.conv(img)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        output = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.fc(feature.view(img.shape[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;], -&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; output&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;
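&lt;p&gt;The input size of the first fully connected layer, &lt;code&gt;256*5*5&lt;/code&gt;, follows from the shape arithmetic of the convolutional part. With dilation 1 and the default floor rounding, each &lt;code&gt;nn.Conv2d&lt;/code&gt; and &lt;code&gt;nn.MaxPool2d&lt;/code&gt; layer maps a side length n to ⌊(n + 2p − k) / s⌋ + 1. Here k is the kernel size, s the stride and p the padding. The plain-Python sketch below traces a 224×224 input through the layers above. The helpers &lt;code&gt;out_size&lt;/code&gt; and &lt;code&gt;trace&lt;/code&gt; are hypothetical and not part of PyTorch. The sketch also shows that the original 28×28 Fashion-MNIST images shrink to zero at the second pooling layer, so they have to be enlarged first. The complete implementation resizes them to 224.&lt;/p&gt;
&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;
```python
def out_size(n, kernel, stride=1, padding=0):
    """Side length after a Conv2d or MaxPool2d layer (dilation 1, floor rounding)."""
    return (n + 2 * padding - kernel) // stride + 1


# (name, kernel_size, stride, padding) of the conv and pooling layers in
# AlexNet.conv, in order; ReLU does not change the shape and is omitted.
LAYERS = [
    ("conv1", 11, 4, 0), ("pool1", 3, 2, 0),
    ("conv2", 5, 1, 2), ("pool2", 3, 2, 0),
    ("conv3", 3, 1, 1), ("conv4", 3, 1, 1), ("conv5", 3, 1, 1),
    ("pool3", 3, 2, 0),
]


def trace(n):
    """Return the side length after each layer, starting from an n x n input."""
    sizes = [n]
    for name, kernel, stride, padding in LAYERS:
        n = out_size(n, kernel, stride, padding)
        if n >= 1:
            sizes.append(n)
        else:
            raise ValueError("%s output size %d is too small" % (name, n))
    return sizes


sizes = trace(224)
print(sizes)                 # [224, 54, 26, 26, 12, 12, 12, 12, 5]
print(256 * sizes[-1] ** 2)  # 6400, i.e. 256*5*5, the in_features of the first nn.Linear

try:
    trace(28)                # original, unresized Fashion-MNIST images
except ValueError as err:
    print(err)               # pool2 output size 0 is too small
```
&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;
&lt;p&gt;A different resize target changes the flattened size, and the first &lt;code&gt;nn.Linear&lt;/code&gt; would then need a matching &lt;code&gt;in_features&lt;/code&gt;.&lt;/p&gt;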

&lt;p&gt;Complete implementation, training AlexNet on Fashion-MNIST with the images resized to 224×224:&lt;/p&gt;
&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;8&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;9&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;10&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;11&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;12&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;13&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;14&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;15&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;16&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;17&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;18&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;19&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;20&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;21&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;22&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;23&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;24&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;25&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;26&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;27&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;28&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;29&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;30&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;31&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;32&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;33&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;34&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;35&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;36&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;37&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;38&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;39&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;40&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;41&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;42&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;43&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;44&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;45&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;46&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;47&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;48&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;49&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;50&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;51&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;52&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;53&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;54&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;55&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;56&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;57&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;58&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;59&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;60&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;61&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;62&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;63&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;64&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;65&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;66&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;67&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;68&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;69&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;70&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;71&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;72&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;73&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;74&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;75&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;76&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;77&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;78&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;79&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;80&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;81&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;82&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;83&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;84&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;85&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;86&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;87&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;88&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;89&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;90&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;91&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;92&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;93&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;94&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;95&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;96&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;97&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;98&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;99&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;100&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;101&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;102&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;103&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;104&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;105&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;106&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;107&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; time&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; torch&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;from&lt;/span&gt; torch &lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; nn, optim&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; torchvision&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; sys&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;device = torch.device(&lt;span class=&quot;string&quot;&gt;&amp;#x27;cuda&amp;#x27;&lt;/span&gt; &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; torch.cuda.is_available() &lt;span class=&quot;keyword&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;string&quot;&gt;&amp;#x27;cpu&amp;#x27;&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title 
function_&quot;&gt;load_data_fashion_mnist&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;batch_size, resize=&lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt;, root=&lt;span class=&quot;string&quot;&gt;&amp;#x27;~/Datasets/FashionMNIST&amp;#x27;&lt;/span&gt;&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; sys.platform.startswith(&lt;span class=&quot;string&quot;&gt;&amp;#x27;win&amp;#x27;&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        num_workers = &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;else&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        num_workers = &lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    trans = []&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; resize:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        trans.append(torchvision.transforms.Resize(size=resize))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    trans.append(torchvision.transforms.ToTensor())&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    transform = torchvision.transforms.Compose(trans)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;, download=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;, transform=transform)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=&lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;, download=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;, 
transform=transform)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=&lt;span class=&quot;literal&quot;&gt;True&lt;/span&gt;, num_workers=num_workers)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=&lt;span class=&quot;literal&quot;&gt;False&lt;/span&gt;, num_workers=num_workers)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; train_iter, test_iter&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;batch_size = &lt;span class=&quot;number&quot;&gt;128&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=&lt;span class=&quot;number&quot;&gt;224&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;title class_&quot;&gt;AlexNet&lt;/span&gt;(nn.Module):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;__init__&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;built_in&quot;&gt;super&lt;/span&gt;(AlexNet, &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;).__init__()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.conv = nn.Sequential(&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;            nn.Conv2d(&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;96&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;11&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;), &lt;span class=&quot;comment&quot;&gt;# in_channels, out_channels, kernel_size, stride, padding&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.MaxPool2d(&lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;), &lt;span class=&quot;comment&quot;&gt;# kernel_size, stride&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;comment&quot;&gt;# smaller 5x5 window; padding 2 keeps the height and width unchanged while the number of output channels grows&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Conv2d(&lt;span class=&quot;number&quot;&gt;96&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;256&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.MaxPool2d(&lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            &lt;span class=&quot;comment&quot;&gt;# three consecutive conv layers with an even smaller 3x3 window; the first two widen to 384 channels, the last returns to 256&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Conv2d(&lt;span class=&quot;number&quot;&gt;256&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;384&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, &lt;span 
class=&quot;number&quot;&gt;1&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Conv2d(&lt;span class=&quot;number&quot;&gt;384&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;384&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Conv2d(&lt;span class=&quot;number&quot;&gt;384&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;256&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.MaxPool2d(&lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;2&lt;/span&gt;)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        )&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.fc = nn.Sequential(&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Linear(&lt;span class=&quot;number&quot;&gt;256&lt;/span&gt;*&lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;*&lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;4096&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Dropout(&lt;span class=&quot;number&quot;&gt;0.5&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Linear(&lt;span 
class=&quot;number&quot;&gt;4096&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;4096&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.ReLU(),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Dropout(&lt;span class=&quot;number&quot;&gt;0.5&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            nn.Linear(&lt;span class=&quot;number&quot;&gt;4096&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;10&lt;/span&gt;),&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        )&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;forward&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self, img&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        feature = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.conv(img)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        output = &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.fc(feature.view(img.shape[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;], -&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; output&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;net = AlexNet()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;evaluate_accuracy&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;data_iter, net, device=&lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt;&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;if&lt;/span&gt; device &lt;span class=&quot;keyword&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;literal&quot;&gt;None&lt;/span&gt; &lt;span class=&quot;keyword&quot;&gt;and&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;isinstance&lt;/span&gt;(net, torch.nn.Module):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;comment&quot;&gt;# if no device is given, use the device that holds the net&#x27;s parameters&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        device = &lt;span class=&quot;built_in&quot;&gt;list&lt;/span&gt;(net.parameters())[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;].device&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    acc_sum, n = &lt;span class=&quot;number&quot;&gt;0.0&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;with&lt;/span&gt; torch.no_grad():&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; X, y &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; data_iter:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            net.&lt;span class=&quot;built_in&quot;&gt;eval&lt;/span&gt;() &lt;span class=&quot;comment&quot;&gt;# evaluation mode: disables dropout&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            acc_sum += (net(X.to(device)).argmax(dim=&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;) == y.to(device)).&lt;span class=&quot;built_in&quot;&gt;float&lt;/span&gt;().&lt;span class=&quot;built_in&quot;&gt;sum&lt;/span&gt;().cpu().item()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            net.train() &lt;span class=&quot;comment&quot;&gt;# switch back to training mode&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            n += y.shape[&lt;span 
class=&quot;number&quot;&gt;0&lt;/span&gt;]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; acc_sum / n&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;train&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;net, train_iter, test_iter, batch_size, optimizer, device, num_epochs&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    net = net.to(device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;built_in&quot;&gt;print&lt;/span&gt;(&lt;span class=&quot;string&quot;&gt;&amp;quot;training on &amp;quot;&lt;/span&gt;, device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    loss = torch.nn.CrossEntropyLoss()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; epoch &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;built_in&quot;&gt;range&lt;/span&gt;(num_epochs):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        train_l_sum, train_acc_sum, n, batch_count, start = &lt;span class=&quot;number&quot;&gt;0.0&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;0.0&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;, time.time()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;for&lt;/span&gt; X, y &lt;span class=&quot;keyword&quot;&gt;in&lt;/span&gt; train_iter:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            X = X.to(device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            y = y.to(device)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;         
   y_hat = net(X)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            l = loss(y_hat, y)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            optimizer.zero_grad()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            l.backward()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            optimizer.step()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            train_l_sum += l.cpu().item()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            train_acc_sum += (y_hat.argmax(dim=&lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;) == y).&lt;span class=&quot;built_in&quot;&gt;sum&lt;/span&gt;().cpu().item()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            n += y.shape[&lt;span class=&quot;number&quot;&gt;0&lt;/span&gt;]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;            batch_count += &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        test_acc = evaluate_accuracy(test_iter, net)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;built_in&quot;&gt;print&lt;/span&gt;(&lt;span class=&quot;string&quot;&gt;&amp;#x27;epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec&amp;#x27;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;              % (epoch + &lt;span class=&quot;number&quot;&gt;1&lt;/span&gt;, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;lr, num_epochs = &lt;span class=&quot;number&quot;&gt;0.001&lt;/span&gt;, &lt;span class=&quot;number&quot;&gt;5&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;optimizer = torch.optim.Adam(net.parameters(), lr=lr)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;train(net, train_iter, test_iter, 
batch_size, optimizer, device, num_epochs)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;


    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/612344293.html"/>
    <id>https://blog.aivgg.com/posts/612344293.html</id>
    <published>2026-06-10T16:18:36.402Z</published>
    <updated>2026-06-10T16:24:04.378Z</updated>
    
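    <content type="html"><![CDATA[<h3 id="背景"><a href="#背景" class="headerlink" title="背景"></a><strong>Background</strong></h3><p>GAIL has two main problems:</p><ul><li>Mode collapse: the samples produced by the generative model collapse onto the sub-distribution of a single mode of the real data distribution and fail to cover the whole real distribution.</li><li>Low utilization efficiency of generated samples: GAIL assumes a stochastic policy and learns it with model-free RL. Because sampling an action from a stochastic policy is not differentiable, the chain rule of backpropagation is cut off at the action node of the policy model. In addition, in a stochastic environment the agent's state transitions are themselves random.</li></ul><p>InfoGAIL mainly addresses the first problem. Taking autonomous driving as an example, GAIL cannot handle multi-expert data with different driving styles well. <strong>Because individual experts differ, the samples follow sub-distributions of several modes, so the single-mode assumption does not match the real problem</strong>.</p>
<h3 id="InfoGAIL"><a href="#InfoGAIL" class="headerlink" title="InfoGAIL"></a>InfoGAIL</h3><p>Core idea: InfoGAIL assumes that the expert data is distributed over multiple modes and learns several valid modes from it simultaneously, such as a fast-driving mode and a safe-driving mode, <strong>adding an auxiliary network that classifies which mode a sample belongs to</strong>. InfoGAIL brings the information-theoretic concept of mutual information into the GAIL model: by maximizing mutual information, it strengthens the correlation between the samples produced by the policy and the latent mode variable, and thereby achieves unsupervised multimodal learning.</p><p>Mutual information is the reduction in uncertainty (the amount of information) about a random variable x once another variable y is known. Informally, it measures how strongly x and y depend on each other: the larger the mutual information, the more related they are. It is defined as:</p><p><img src="https://img-blog.csdnimg.cn/ca6cfff9e07b4a7b9ed483d1b73f25e0.png" alt="在这里插入图片描述"></p><p>On top of GAIL, InfoGAIL additionally maximizes the mutual information between the latent mode variable and the state-action pairs produced by the policy being learned:</p><p><img src="https://img-blog.csdnimg.cn/6f313f1cfbe34d47bf60d66f3ea9bd0d.png" alt="在这里插入图片描述"></p><p>The concrete objective is obtained by adding this mutual-information term as a regularizer to the original GAIL objective:</p><p><img src="https://img-blog.csdnimg.cn/9870f9cf83614535ab49d4524427fad6.png" alt="在这里插入图片描述"></p><p>Because mode labels are not available, the mutual information cannot be computed directly: it requires the posterior of the latent code given a trajectory. Following InfoGAN, the mutual information is relaxed to a variational lower bound, and a network model Y approximates the posterior:</p><p><img src="https://img-blog.csdnimg.cn/a14bddb47cbb4e02a112ce631ba13604.png" alt="在这里插入图片描述"></p><p>Comparison of the original GAN, GAIL and InfoGAN:</p><p><img src="https://img-blog.csdnimg.cn/ad806b0caedd4d309e8ea2092a834f76.png" alt="在这里插入图片描述"></p><p><img src="https://img-blog.csdnimg.cn/e7ee5a1bf8824515b6cd60161ac1033a.png" alt="在这里插入图片描述"></p><p><img src="https://img-blog.csdnimg.cn/5840e2a5ecb64dd8b3f8db1113bed76f.png" alt="在这里插入图片描述"></p><p>Training framework of InfoGAIL:</p><p><img src="https://img-blog.csdnimg.cn/53e06a44e71541488f6ac10734d09d0d.png" alt="在这里插入图片描述"></p>
<p>In the InfoGAIL training scheme, the discriminator D plays the same role as in the original GAIL: it guides the samples produced by π toward the distribution of the expert samples. The inference network Y takes the (s, a) pairs produced by policy π as input and infers their posterior probability. Y neither receives nor processes expert samples. Following the principle of mutual information maximization, Y keeps improving its inference model so that it recovers the latent mode variable most strongly related to the samples produced by π. In turn, Y drives the policy to produce state-action pairs that are correlated with the latent variable.</p><p>There are two further optimizations:</p><p>1. Reward Augmentation: incorporating prior knowledge. Since the expert policies are themselves suboptimal, a policy that only imitates them cannot reach optimal performance, so a state-based reward function is introduced:</p><p><img src="https://img-blog.csdnimg.cn/11a4f04c04c5438092cda967d690b045.png" alt="在这里插入图片描述"></p><p>2. Improved Optimization: to improve performance on tasks with high-dimensional inputs and to avoid the vanishing-gradient problem of GANs, the WGAN framework is used:</p><p><img src="https://img-blog.csdnimg.cn/914819c8fcd34da4aaa1522ffb15414b.png" alt="在这里插入图片描述"></p><p>The results are shown below. BC directly clones the single-step state-to-action mapping of the expert samples, so small errors are amplified step by step over the sequential decision process. GAIL assumes that all data comes from a single expert and tends toward an averaged policy. InfoGAIL can distinguish the behaviors of different experts:</p><p><img src="https://img-blog.csdnimg.cn/93824cf1469c4db1b9b2b7a41c80311f.png" alt="在这里插入图片描述"></p>]]></content>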
    
    <summary type="html">
    
      &lt;h3 id=&quot;背景&quot;&gt;&lt;a href=&quot;#背景&quot; class=&quot;headerlink&quot; title=&quot;背景&quot;&gt;&lt;/a&gt;&lt;strong&gt;Background&lt;/strong&gt;&lt;/h3&gt;&lt;p&gt;GAIL has two main problems:&lt;/p&gt;
&lt;ul&gt;
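&lt;li&gt;Mode collapse: the samples produced by the generative model collapse onto the sub-distribution of a single mode of the real data distribution and fail to cover the whole real distribution.&lt;/li&gt;
&lt;li&gt;Low utilization efficiency of generated samples: GAIL assumes a stochastic policy and learns it with model-free RL. Because sampling an action from a stochastic policy is not differentiable, the chain rule of backpropagation is cut off at the action node of the policy model. In addition, in a stochastic environment the agent's state transitions are themselves random.&lt;/li&gt;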
&lt;/ul&gt;
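&lt;p&gt;InfoGAIL mainly addresses the first problem. Taking autonomous driving as an example, GAIL cannot handle multi-expert data with different driving styles well. &lt;strong&gt;Because individual experts differ, the samples follow sub-distributions of several modes, so the single-mode assumption does not match the real problem&lt;/strong&gt;.&lt;/p&gt;
&lt;h3 id=&quot;InfoGAIL&quot;&gt;&lt;a href=&quot;#InfoGAIL&quot; class=&quot;headerlink&quot; title=&quot;InfoGAIL&quot;&gt;&lt;/a&gt;InfoGAIL&lt;/h3&gt;&lt;p&gt;Core idea: InfoGAIL assumes that the expert data is distributed over multiple modes and learns several valid modes from it simultaneously, such as a fast-driving mode and a safe-driving mode, &lt;strong&gt;adding an auxiliary network that classifies which mode a sample belongs to&lt;/strong&gt;. InfoGAIL brings the information-theoretic concept of mutual information into the GAIL model: by maximizing mutual information, it strengthens the correlation between the samples produced by the policy and the latent mode variable, and thereby achieves unsupervised multimodal learning.&lt;/p&gt;
&lt;p&gt;Mutual information is the reduction in uncertainty (the amount of information) about a random variable x once another variable y is known. Informally, it measures how strongly x and y depend on each other: the larger the mutual information, the more related they are. It is defined as:&lt;/p&gt;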
&lt;p&gt;&lt;img src=&quot;https://img-blog.csdnimg.cn/ca6cfff9e07b4a7b9ed483d1b73f25e0.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
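&lt;p&gt;To make the definition concrete, here is a minimal sketch (not from the original post) that computes I(X;Y) in nats for two discrete variables from their joint distribution. It uses the equivalent form I(X;Y) = H(X) − H(X|Y) = Σ p(x,y) log[p(x,y) / (p(x)p(y))]. The function name and the list-of-lists input format are assumptions for illustration.&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
```python
import math


def mutual_information(joint):
    """Return I(X;Y) in nats for a discrete joint distribution.

    joint[i][j] is P(X = i, Y = j); the entries are assumed to sum to 1.
    """
    px = [sum(row) for row in joint]          # marginal P(X = i)
    py = [sum(col) for col in zip(*joint)]    # marginal P(Y = j)
    mi = 0.0
    for i, row in enumerate(joint):
        for j, pxy in enumerate(row):
            if pxy > 0:                        # terms with P(x, y) = 0 contribute 0
                mi += pxy * math.log(pxy / (px[i] * py[j]))
    return mi


# Y is a copy of X: knowing Y removes all uncertainty, so I(X;Y) = H(X) = log 2.
print(mutual_information([[0.5, 0.0], [0.0, 0.5]]))
# X and Y are independent: knowing Y tells nothing about X, so I(X;Y) = 0.
print(mutual_information([[0.25, 0.25], [0.25, 0.25]]))
```
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;The two calls show the extremes. Perfectly dependent variables share all log 2 nats of H(X), while independent variables share none.&lt;/p&gt;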
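&lt;p&gt;On top of GAIL, InfoGAIL additionally maximizes the mutual information between the latent mode variable and the state-action pairs produced by the policy being learned:&lt;/p&gt;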
&lt;p&gt;&lt;img src=&quot;https://img-blog.csdnimg.cn/6f313f1cfbe34d47bf60d66f3ea9bd0d.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
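&lt;p&gt;The posterior over the latent code is unknown, so InfoGAIL (following InfoGAN) maximizes a variational lower bound instead: L_I = E[log Q(c|τ)] + H(c), where Q approximates the posterior (the inference network Y in this post). The sketch below is a minimal Monte Carlo estimate of that bound. It assumes a uniform prior over K discrete codes, so H(c) = log K; the function name and input format are hypothetical.&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
```python
import math


def info_lower_bound(log_q_of_true_code, num_codes):
    """Monte Carlo estimate of L_I = E[log Q(c | tau)] + H(c).

    log_q_of_true_code: for each sampled trajectory tau_i, the value
        log Q(c_i | tau_i), where c_i is the latent code that was fed to the
        policy when tau_i was generated (hypothetical input format).
    num_codes: number of discrete modes K; the prior p(c) is assumed to be
        uniform, so H(c) = log K.
    """
    expected_log_q = sum(log_q_of_true_code) / len(log_q_of_true_code)
    prior_entropy = math.log(num_codes)
    return expected_log_q + prior_entropy


K = 4
# Q is always certain and correct: log Q = 0, so the bound reaches log K.
print(info_lower_bound([0.0] * 8, K))
# Q outputs the uniform distribution 1/K: the bound is 0 (no information recovered).
print(info_lower_bound([math.log(1 / K)] * 8, K))
```
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;The bound equals log K when Q always recovers the code that generated the trajectory. It equals 0 when Q outputs a uniform distribution, and it can even turn negative when Q is confidently wrong.&lt;/p&gt;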
    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/2968666401.html"/>
    <id>https://blog.aivgg.com/posts/2968666401.html</id>
    <published>2026-06-10T16:18:36.401Z</published>
    <updated>2026-06-10T16:24:04.378Z</updated>
    
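    <content type="html"><![CDATA[<p>Generative Adversarial Imitation Learning: the paper that introduced GAIL. It is math-heavy; a quick skim is enough, no need to study it in depth.</p><p>Imitating Driver Behavior with Generative Adversarial Networks: the first paper to apply the GAIL model to NPC vehicles.<br>Wasserstein GAN: a GAN variant that serves as the training framework for GAIL.<br>InfoGAIL: introduces a latent variable so that a single model can learn multiple driving styles.</p><p>RAIL modifies the reinforcement learning reward function to penalize dangerous driving behavior.</p>
<h3 id="模仿学习"><a href="#模仿学习" class="headerlink" title="模仿学习"></a>Imitation Learning</h3><p>Imitation learning solves decision-making problems by imitating samples demonstrated by an expert. It does not need reward feedback from the environment; its feedback comes from the expert's decision samples. In many practical problems, obtaining expert samples is easier and cheaper than designing a suitable reward function.</p><p>Imitation learning methods fall into two categories: behavioral cloning (BC) and imitation learning via inverse reinforcement learning (IRL-IL).</p><ul><li><strong>The main idea of behavioral cloning is to directly clone the single-step state-to-action mapping in the expert samples</strong>, i.e. to perform supervised learning on the expert samples. BC does not consider the long-term consequences beyond the current state. With enough expert samples it performs well, but because it ignores long-term consequences, small errors are amplified step by step over the sequential decision process, which is known as the compounding error problem.</li><li><strong>Inverse reinforcement learning assumes that the expert policy is equivalent to the optimal policy derived from an unknown true reward function</strong>. IRL is the inverse of RL: it recovers the unknown reward function from the given expert samples, and then indirectly recovers the expert policy by solving for the optimal policy under that reward with an RL method. Imitating the expert this way gives IRL-IL the ability to plan over the long term.</li></ul><p><strong>Imitation learning based on generative adversarial networks</strong> (GANs-IL) developed from IRL-IL and is a family of imitation learning methods that incorporate GANs. The two differ mainly in how the reward function and the policy are represented and how the models are trained. GANs-IL <strong>represents the reward function and the policy of IRL-IL with two neural networks and optimizes their parameters adversarially</strong>. The original GAN consists of two competing networks: a generative model (the generator) and a discriminative model (the discriminator).</p><p><img src="https://img-blog.csdnimg.cn/8a42d26a8de848cd927298ac5af979ce.png" alt="在这里插入图片描述"></p><p>The goal of imitation learning is to <strong>learn a decision model that is as similar to the expert as possible</strong>. Accordingly, imitation learning is usually evaluated by comparing the performance of the learned policy with that of the expert policy.</p><p>There are two main ways to obtain the expert sample set: 1) from demonstrations by human experts; 2) by using an RL method to learn a greedy policy on a standard reward function hand-designed by experts, and then generating the expert samples with that greedy policy. However, a greedy policy obtained with RL is not necessarily the optimal policy, and greedy policies obtained with different RL methods differ in performance. Therefore, <strong>expert sample sets obtained with RL methods have not become a standard</strong>.</p><p>At present, imitation learning is studied mostly in simulated environments, such as simulated cars and virtual robot control. The way expert samples are obtained varies from task to task. For difficult tasks, a standard reward function is often hard to define, so <strong>collecting expert samples from the expert's own demonstrations is more direct</strong>. For tasks that involve danger, obtaining expert samples with RL in a virtual environment is more appropriate.</p>
<h3 id="GAIL"><a href="#GAIL" class="headerlink" title="GAIL"></a>GAIL</h3><p>Generative Adversarial Imitation Learning, proposed in 2016, is the earliest and most representative GANs-IL method. In GAIL, the policy, which outputs actions given input states, is analogous to the generator, while the reward function, which outputs a reward value given expert or generated samples, is analogous to the discriminator. GAIL therefore treats solving for the reward function as training the discriminator, and learning the policy as training the generator.</p><p><img src="https://img-blog.csdnimg.cn/6d3e9476bd8c43aea2293a477f33a9fa.png" alt="在这里插入图片描述"></p><p>Problems:</p><ul><li>Mode collapse: the samples produced by the generative model collapse onto the sub-distribution of a single mode of the real data distribution and fail to cover the whole real distribution. For image samples, for example, mode collapse makes the generated images show only a single scene or a single style, losing sample diversity.</li><li>Low utilization efficiency of generated samples: GAIL assumes a stochastic policy <strong>and learns it with model-free RL</strong>. Because sampling an action from a stochastic policy is not differentiable, the chain rule of backpropagation is cut off at the action node of the policy model. In addition, in a stochastic environment the agent's state transitions are themselves random.</li></ul>
<h3 id="ACGAIL"><a href="#ACGAIL" class="headerlink" title="ACGAIL"></a>ACGAIL</h3><p>When expert samples follow sub-distributions of multiple modes, the single-mode assumption of imitation learning leads to mode collapse. It is therefore more reasonable to assume that the expert has multiple modes. Multimodal imitation learning relaxes the single-mode assumption: <strong>it assumes the expert samples have multiple modes, i.e. the demonstrations come not from a single mode but from several sub-distributions under different modes</strong>. Under this assumption, GAIL's mode collapse can be alleviated.</p><p>Adding an auxiliary network model to GAIL gives Generative Adversarial Imitation Learning with Auxiliary Classifier (ACGAIL). The new auxiliary network classifies which mode a sample belongs to, helping the original GAIL model <strong>reconstruct the conditional information about the mode</strong>.</p><p><img src="https://img-blog.csdnimg.cn/95a4d4eca572455da293c3e5b3b5b3ea.png" alt="在这里插入图片描述"></p>
<h3 id="InfoGAIL"><a href="#InfoGAIL" class="headerlink" title="InfoGAIL"></a>InfoGAIL</h3><p>Information Maximizing Generative Adversarial Imitation Learning (InfoGAIL) applies the information-theoretic concept of mutual information to GAIL. By maximizing mutual information, <strong>InfoGAIL strengthens the correlation between the samples produced by the policy and the latent mode variable, achieving unsupervised multimodal learning.</strong></p><p><img src="https://img-blog.csdnimg.cn/9e90f1aed5794293a86329353a3f876b.png" alt="在这里插入图片描述"></p><p>ACGAIL and InfoGAIL make the same assumption about the prior distribution of the mode variable. Both <strong>obtain the mode variable by random sampling</strong> and <strong>assume the expert samples have a finite number of modes, with the mode variable following a discrete uniform distribution</strong>. Both <strong>introduce an extra classification model</strong> into the original GAIL structure: the classifier C and the inference network Y, respectively. ACGAIL's classifier C can be trained <strong>with supervision</strong> using existing mode labels, whereas InfoGAIL's inference network is trained <strong>without supervision</strong>. Moreover, C and Y each <strong>jointly form the reward function together with the discriminator</strong>.</p>
<h3 id="MAGAIL"><a href="#MAGAIL" class="headerlink" title="MAGAIL"></a>MAGAIL</h3><p>Multi-Agent Generative Adversarial Imitation Learning (MAGAIL) assumes there are k agents in the environment, with k corresponding discriminators. Each discriminator scores the policy of its agent against that agent's expert policy, trying to give the expert policy high scores and the agent's policy low scores. Each agent in turn tries to produce behavior that fools its discriminator, and so learns to imitate the expert policy under the discriminator's guidance.</p><p><strong>In multi-agent learning problems there are prior assumptions about the relationships between agents</strong>, for example that they are <strong>cooperative, competitive, or mixed</strong>. Under different assumptions, the discriminators in a multi-agent problem take different forms.</p><ul><li>Centralized: when the agents are fully cooperative, all agents in MAGAIL effectively share one discriminator. This special case can be viewed as the original GAIL, and the learned joint policy applies to all agents.</li><li>Decentralized: when no correlation between the agents' rewards is assumed, each agent's discriminator uses its own scoring criterion. However, because these discriminators keep interacting indirectly through the environment, they are not completely independent of each other.</li><li>Zero-sum: if two agents are fully competitive, their rewards are negatives of each other. In a zero-sum game, the agents need no additional interaction with the environment; the discriminator is trained directly on the joint samples of the agents and the experts.</li></ul><p><img src="https://img-blog.csdnimg.cn/40e89e24e8c04bea8a8d544a14258185.png" alt="在这里插入图片描述"></p><p>WARNING: The repository located at mirrors.aliyun.com is not a trusted or secure host and is being ignored. If this repository is available via HTTPS we recommend you use HTTPS instead, otherwise you may silence this warning and allow it anyway with '--trusted-host mirrors.aliyun.com'.<br>ERROR: Could not find a version that satisfies the requirement carla (from versions: none)<br>ERROR: No matching distribution found for carla</p>]]></content>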
    
    <summary type="html">
    
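      &lt;p&gt;Generative Adversarial Imitation Learning: the paper that introduced GAIL. It is math-heavy; a quick skim is enough, no need to study it in depth.&lt;/p&gt;
&lt;p&gt;Imitating Driver Behavior with Generative Adversarial Networks: the first paper to apply the GAIL model to NPC vehicles.&lt;br&gt;Wasserstein GAN: a GAN variant that serves as the training framework for GAIL.&lt;br&gt;InfoGAIL: introduces a latent variable so that a single model can learn multiple driving styles.&lt;/p&gt;
&lt;p&gt;RAIL modifies the reinforcement learning reward function to penalize dangerous driving behavior.&lt;/p&gt;
&lt;h3 id=&quot;模仿学习&quot;&gt;&lt;a href=&quot;#模仿学习&quot; class=&quot;headerlink&quot; title=&quot;模仿学习&quot;&gt;&lt;/a&gt;Imitation Learning&lt;/h3&gt;&lt;p&gt;Imitation learning solves decision-making problems by imitating samples demonstrated by an expert. It does not need reward feedback from the environment; its feedback comes from the expert's decision samples. In many practical problems, obtaining expert samples is easier and cheaper than designing a suitable reward function.&lt;/p&gt;
&lt;p&gt;Imitation learning methods fall into two categories: behavioral cloning (BC) and imitation learning via inverse reinforcement learning (IRL-IL).&lt;/p&gt;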
&lt;ul&gt;
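&lt;li&gt;&lt;strong&gt;The main idea of behavioral cloning is to directly clone the single-step state-to-action mapping in the expert samples&lt;/strong&gt;, i.e. to perform supervised learning on the expert samples. BC does not consider the long-term consequences beyond the current state. With enough expert samples it performs well, but because it ignores long-term consequences, small errors are amplified step by step over the sequential decision process, which is known as the compounding error problem.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Inverse reinforcement learning assumes that the expert policy is equivalent to the optimal policy derived from an unknown true reward function&lt;/strong&gt;. IRL is the inverse of RL: it recovers the unknown reward function from the given expert samples, and then indirectly recovers the expert policy by solving for the optimal policy under that reward with an RL method. Imitating the expert this way gives IRL-IL the ability to plan over the long term.&lt;/li&gt;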
&lt;/ul&gt;
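&lt;p&gt;The supervised-learning view of BC can be illustrated with a minimal tabular sketch (an illustration, not code from the post). For each state seen in the demonstrations, it fits the action the expert chose most often. The function name and the (state, action) input format are assumptions; practical BC replaces the table with a classifier or regressor such as a neural network.&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
```python
from collections import Counter, defaultdict


def behavioral_cloning(expert_pairs):
    """Tabular behavioral cloning: supervised learning of a state -> action map.

    expert_pairs: iterable of (state, action) tuples from expert demonstrations.
    Returns a dict policy that picks, for each seen state, the action the expert
    chose most often (the maximum-likelihood estimate for a tabular policy).
    """
    counts = defaultdict(Counter)
    for state, action in expert_pairs:
        counts[state][action] += 1
    return {s: c.most_common(1)[0][0] for s, c in counts.items()}


demos = [("s0", "left"), ("s0", "left"), ("s0", "right"), ("s1", "up")]
policy = behavioral_cloning(demos)
print(policy)  # {'s0': 'left', 's1': 'up'}
```
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;States that never appear in the demonstrations have no entry here. A learned BC policy does generalize to such states, but its small mistakes can lead the agent into states unlike the expert's, which is how the compounding errors described above arise.&lt;/p&gt;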
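&lt;p&gt;&lt;strong&gt;Imitation learning based on generative adversarial networks&lt;/strong&gt; (GANs-IL) developed from IRL-IL and is a family of imitation learning methods that incorporate GANs. The two differ mainly in how the reward function and the policy are represented and how the models are trained. GANs-IL &lt;strong&gt;represents the reward function and the policy of IRL-IL with two neural networks and optimizes their parameters adversarially&lt;/strong&gt;. The original GAN consists of two competing networks: a generative model (the generator) and a discriminative model (the discriminator).&lt;/p&gt;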
&lt;p&gt;&lt;img src=&quot;https://img-blog.csdnimg.cn/8a42d26a8de848cd927298ac5af979ce.png&quot; alt=&quot;在这里插入图片描述&quot;&gt;&lt;/p&gt;
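&lt;p&gt;In the GAN view of imitation learning, the discriminator stands in for the reward function, and the policy is trained with RL on a reward derived from it. The sketch below assumes D(s, a) is the discriminator's probability that a state-action pair came from the expert. The name gail_reward and the specific form −log(1 − D) are one common choice; papers and libraries differ in the sign and form of this reward.&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
```python
import math


def gail_reward(d_expert_prob, eps=1e-8):
    """Surrogate RL reward computed from a GAIL-style discriminator output.

    Assumes d_expert_prob = D(s, a) is the probability that the pair came from
    the expert. The reward -log(1 - D) grows as the policy's samples look more
    expert-like. This is an illustrative convention, not the only one.
    """
    d = min(max(d_expert_prob, eps), 1.0 - eps)  # clip to avoid log(0)
    return -math.log(1.0 - d)


for d in (0.1, 0.5, 0.9):
    # The reward increases as the discriminator finds the pair more expert-like.
    print(d, round(gail_reward(d), 4))
```
&lt;/code&gt;&lt;/pre&gt;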
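&lt;p&gt;The goal of imitation learning is to &lt;strong&gt;learn a decision model that is as similar to the expert as possible&lt;/strong&gt;. Accordingly, imitation learning is usually evaluated by comparing the performance of the learned policy with that of the expert policy.&lt;/p&gt;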
    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/3083270428.html"/>
    <id>https://blog.aivgg.com/posts/3083270428.html</id>
    <published>2026-06-10T16:18:36.243Z</published>
    <updated>2026-06-10T16:24:04.372Z</updated>
    
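    <content type="html"><![CDATA[<h4 id="7、-基于规则和网络结合的智能模型构建技术"><a href="#7、-基于规则和网络结合的智能模型构建技术" class="headerlink" title="7、 基于规则和网络结合的智能模型构建技术"></a>7. Intelligent Model Construction Combining Rules and Networks</h4><p>Traditional rule-based agents usually make decisions with behavior trees or state machines. Although this works in some scenarios, such agents explore little of the decision space and show limited intelligence. Pure reinforcement learning agents, on the other hand, explore strongly and can search for optimal solutions, but they typically suffer from difficult learning, unstable results and complex action modeling.</p><p>To better handle complex military scenarios, this project builds its game-playing agents on a <strong>decision framework that fuses knowledge rules with learned models</strong>.</p>
<h4 id="7-1-智能体构建框架"><a href="#7-1-智能体构建框架" class="headerlink" title="7.1  智能体构建框架"></a>7.1 Agent Construction Framework</h4><p>1. Layering: a high-level agent plus finite state machines</p><p>In military scenarios, decision tasks are often constrained by doctrine, and different tasks may have ordering constraints and entangled dependencies. To handle this, we designed the agent's decision process in layers.</p><p>At the upper layer, an agent makes coarse-grained decisions. Its main responsibilities are to determine task types and targets, formulate high-level commands and plans according to the overall strategy, and assign tasks and priorities to the lower-layer agents.</p><p>At the lower layer, tasks are implemented concretely with finite state machines. A finite state machine lets the agent make different decisions flexibly according to its current state and the environmental conditions, adapting to the needs of the task.</p><p>This layered design makes the decision system more organized and flexible. The upper-layer agent handles overall planning and task assignment, guiding military operations from a macro perspective, while the lower-layer agents make local decisions through finite state machines, so that they can react in time to different situations.</p><p>Because battlefield decisions are divided into preset tasks and real-time decision tasks, and to maximize exploration efficiency without sacrificing adversarial performance, the agent was finally implemented with the following structure:</p><p><img src="/Users/jiangzhenjie/Desktop/doc_asset/agent2.png" alt="agent"></p>
<h4 id="7-2-规则任务框架"><a href="#7-2-规则任务框架" class="headerlink" title="7.2 规则任务框架"></a>7.2 Rule Task Framework</h4><p>Advance planning</p><ul><li>Initial offensive strategy generation</li><li>Coarse-grained target assignment</li><li>Static task planning</li></ul><p>Real-time decision-making</p><ul><li><p>Force scheduling</p></li><li><p>Task execution</p></li></ul><p>Instant analysis</p><ul><li><p>Fine-grained unit scheduling</p></li><li><p>Fine-grained target assignment</p></li><li><p>Dynamic task generation</p></li></ul><p>To abstract scenario tasks better, we analyzed the agent's decisions and split the model's decision information into a subject, a predicate, an object and other information.</p>
<h5 id="基础规则构建-规则配置："><a href="#基础规则构建-规则配置：" class="headerlink" title="基础规则构建-规则配置："></a>Basic rule construction: rule configuration</h5><p>One-shot tasks (a task corresponds to a single execution command and ends once that command has been executed)</p>
<figure class="highlight json"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="punctuation">&#123;</span></span><br><span class="line">  <span class="attr">&quot;unit_ids&quot;</span><span class="punctuation">:</span> <span class="punctuation">[</span><span class="string">&quot;DDG 81“温斯特.S.丘吉尔”阿里伯克级Flight IIA导弹驱逐舰&quot;</span><span class="punctuation">]</span><span class="punctuation">,</span>   # subject: blue-side unit id</span><br><span class="line">  <span class="attr">&quot;type&quot;</span><span class="punctuation">:</span> <span class="string">&quot;NavalAsuWStrike_Naval&quot;</span><span class="punctuation">,</span>  # predicate: ship-to-ship strike</span><br><span class="line">  <span class="attr">&quot;target_ids&quot;</span><span class="punctuation">:</span> <span class="punctuation">[</span><span class="string">&quot;003&quot;</span><span class="punctuation">]</span>  # object: target</span><br><span class="line"><span class="punctuation">&#125;</span></span><br></pre></td></tr></table></figure>
<p>Persistent tasks</p><p>A one-shot task can be turned into a persistent task by configuring its execution time and number of repetitions.</p><p>For example:</p>
<figure class="highlight json"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="punctuation">&#123;</span></span><br><span class="line">  <span class="attr">&quot;unit_ids&quot;</span><span class="punctuation">:</span> <span class="punctuation">[</span><span class="string">&quot;DDG 81“温斯特.S.丘吉尔”阿里伯克级Flight IIA导弹驱逐舰&quot;</span><span class="punctuation">]</span><span class="punctuation">,</span>   # subject: blue-side unit id</span><br><span class="line">  <span class="attr">&quot;type&quot;</span><span class="punctuation">:</span> <span class="string">&quot;NavalAsuWStrike_Naval&quot;</span><span class="punctuation">,</span>  # predicate: ship-to-ship strike</span><br><span class="line">  <span class="attr">&quot;target_ids&quot;</span><span class="punctuation">:</span> <span class="punctuation">[</span><span class="string">&quot;003&quot;</span><span class="punctuation">]</span>  # object: target</span><br><span class="line"><span class="punctuation">&#125;</span></span><br></pre></td></tr></table></figure>
<p>Another kind of persistent task requires several execution commands to be carried out in coordination.</p><p>For example, an air-to-ship strike involves aircraft takeoff, movement, attack and return, so different commands must be issued to the units at each execution stage, e.g.:</p>
<figure class="highlight json"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="punctuation">&#123;</span></span><br><span class="line">  <span class="attr">&quot;unit_ids&quot;</span><span class="punctuation">:</span> <span class="punctuation">[</span><span class="string">&quot;F35C-01&quot;</span><span class="punctuation">,</span> <span class="string">&quot;F35C-02&quot;</span><span class="punctuation">]</span><span class="punctuation">,</span>   # subject: blue-side unit ids</span><br><span class="line">  <span class="attr">&quot;type&quot;</span><span class="punctuation">:</span> <span class="string">&quot;AirIntercept&quot;</span><span class="punctuation">,</span>  # predicate: air intercept</span><br><span class="line">  <span class="attr">&quot;target_ids&quot;</span><span class="punctuation">:</span> <span class="punctuation">[</span><span class="string">&quot;J-15c-01&quot;</span><span class="punctuation">]</span>  # object: target</span><br><span class="line"><span class="punctuation">&#125;</span></span><br></pre></td></tr></table></figure>
<p>In addition, during task design a number of other parameters were abstracted from the zz directives to implement and optimize specific task-execution logic:</p>
<figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">&quot;activation_time&quot;: &quot;2022-02-12&quot;,  # task activation time</span><br><span class="line">&quot;course&quot;: [position_1, position_2, ...],  # route</span><br><span class="line">&quot;attack_mode&quot;: AttackMode.Repeat,  # strike mode</span><br><span class="line">&quot;block&quot;: False,   # whether the unit is locked so that other tasks cannot take it over</span><br></pre></td></tr></table></figure>
<p>Overview of the task finite state machine</p><p><img src="/Users/jiangzhenjie/Desktop/doc_asset/submission.png" alt="submission"></p><p>A complex air-intercept task involves not only basic actions such as selecting an aircraft formation, takeoff, movement, attack and return, but also many contingencies, such as encountering enemy aircraft, running low on fuel or losing the target. To keep the task effective in actual execution and to allow timely dynamic decisions, we introduced fixed tasks and triggered tasks.</p><p>When updating the finite state machine, the task execution flow was refined and transition conditions were defined for each task state. This gives better control over the aircraft's behavior in different situations and copes with the complexity and variability of the task.</p><p>To guarantee execution quality, fixed tasks are determined before the task starts and include the basic flight actions and the attack plan. They ensure the aircraft's basic behavior and goal orientation during execution.</p><p>To handle contingencies, triggered tasks are introduced. When an aircraft meets a specific situation, such as encountering enemy aircraft, running low on fuel or losing its target, the corresponding triggered task is activated, so that a targeted dynamic decision can be made immediately and the aircraft's course of action adjusted based on real-time information.</p><p>With this refined task planning and the introduction of triggered tasks, the decision system can maintain execution quality in complex, changing air-intercept tasks and adapt to different emergencies. This flexible, real-time design supports successful task execution and helps the agent cope with diverse battlefield challenges.</p><p>Task interaction logic</p><p><img src="/Users/jiangzhenjie/Desktop/doc_asset/mission.png" alt="mission"></p>
<h4 id="7-3-多智能体分目标协作"><a href="#7-3-多智能体分目标协作" class="headerlink" title="7.3 多智能体分目标协作"></a>7.3 Multi-Agent Collaboration by Sub-Goal</h4><p>During decision-making, every interaction with the environment triggers a round of data packing and parsing. To speed up training, the environment is usually called with a relatively long decision interval.</p><p>During training we found that sensitivity to the decision interval differs markedly between task types. For air-defense tasks, the decision window is usually under 30 s and decisions are extremely sensitive to the interval: too long an interval sharply degrades model performance. Ship-to-ship strike tasks, by contrast, tolerate the interval well and allow larger gaps.</p><p>With a single model, however, a short interval such as one decision every 5 seconds causes more than 95% of the actions in ship-to-ship strike tasks to be treated as invalid commands because weapons are exhausted or cooling down, which makes training extremely difficult.</p><p>To overcome this, we split the decision process of one agent into multiple agents. Each agent has its own decision interval and interacts with the environment independently. This multi-agent decision approach effectively mitigates the adverse effect of decision frequency on training.</p><p><img src="/Users/jiangzhenjie/Desktop/doc_asset/time.png" alt="time"></p>
<h2 id="训练效果"><a href="#训练效果" class="headerlink" title="训练效果"></a>Training Results</h2><p>In the anti-surface scenario (built with reference to duitai_asuw2.1), with a model decision interval of 2 minutes, 200 data records per episode and batch_size&#x3D;4096, a single training run of 100,000 episodes takes 3 days. Hardware: 35 CPU cores, 1 GPU, 20 sampling nodes.</p><p><img src="/Users/jiangzhenjie/Desktop/%E6%9D%90%E6%96%99/image/zc3%E5%A5%96%E5%8A%B1%E5%87%BD%E6%95%B0.png" alt="zc3奖励函数"></p>
<h2 id="任务规划"><a href="#任务规划" class="headerlink" title="任务规划"></a>Task Planning</h2><blockquote><p>Task-scenario mapping</p></blockquote><table><thead><tr><th>Command type</th><th>Formation air defense</th><th>Maritime air defense</th><th>Anti-surface strike</th><th>Template: unit count</th><th>Template: weapon count</th><th>Initial route planning</th></tr></thead><tbody><tr><td>Air strike: air-to-air</td><td></td><td></td><td>√</td><td>2</td><td>&#x2F;</td><td>Avoid the highest-threat area</td></tr><tr><td>Air strike: air-to-sea</td><td></td><td></td><td>√</td><td>2</td><td></td><td>Avoid the highest-threat area</td></tr><tr><td></td><td></td><td></td><td></td><td></td><td></td><td></td></tr><tr><td>Air patrol: air-to-air</td><td></td><td></td><td>√</td><td>2</td><td></td><td>Head to the highest-threat area</td></tr><tr><td></td><td></td><td></td><td></td><td></td><td></td><td></td></tr><tr><td>Direct attack: ground-to-air (missile)*</td><td>√</td><td></td><td></td><td></td><td>2</td><td></td></tr><tr><td>Direct attack: ground-to-air (aircraft)*</td><td>√</td><td></td><td></td><td></td><td>2</td><td></td></tr><tr><td>Direct attack: sea-to-air (missile)</td><td>√</td><td>√</td><td></td><td></td><td>2</td><td></td></tr><tr><td>Direct attack: sea-to-air (aircraft)</td><td>√</td><td>√</td><td></td><td></td><td>2</td><td></td></tr><tr><td>Direct attack: sea-to-sea</td><td></td><td></td><td>√</td><td></td><td>4&#x2F;8, generated according to target type</td><td></td></tr><tr><td>Standby: empty action (no-op)</td><td>√</td><td>√</td><td>√</td><td></td><td></td><td></td></tr></tbody></table><p>In formation scenarios, ground units and maritime units are treated as the same type of unit and scheduled jointly.</p>]]></content>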
    
    <summary type="html">
    
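      &lt;h4 id=&quot;7、-基于规则和网络结合的智能模型构建技术&quot;&gt;&lt;a href=&quot;#7、-基于规则和网络结合的智能模型构建技术&quot; class=&quot;headerlink&quot; title=&quot;7、 基于规则和网络结合的智能模型构建技术&quot;&gt;&lt;/a&gt;7. Intelligent Model Construction Combining Rules and Networks&lt;/h4&gt;&lt;p&gt;Traditional rule-based agents usually make decisions with behavior trees or state machines. Although this works in some scenarios, such agents explore little of the decision space and show limited intelligence. Pure reinforcement learning agents, on the other hand, explore strongly and can search for optimal solutions, but they typically suffer from difficult learning, unstable results and complex action modeling.&lt;/p&gt;
&lt;p&gt;To better handle complex military scenarios, this project builds its game-playing agents on a &lt;strong&gt;decision framework that fuses knowledge rules with learned models&lt;/strong&gt;.&lt;/p&gt;
&lt;h4 id=&quot;7-1-智能体构建框架&quot;&gt;&lt;a href=&quot;#7-1-智能体构建框架&quot; class=&quot;headerlink&quot; title=&quot;7.1  智能体构建框架&quot;&gt;&lt;/a&gt;7.1 Agent Construction Framework&lt;/h4&gt;&lt;p&gt;1. Layering: a high-level agent plus finite state machines&lt;/p&gt;
&lt;p&gt;In military scenarios, decision tasks are often constrained by doctrine, and different tasks may have ordering constraints and entangled dependencies. To handle this, we designed the agent's decision process in layers.&lt;/p&gt;
&lt;p&gt;At the upper layer, an agent makes coarse-grained decisions. Its main responsibilities are to determine task types and targets, formulate high-level commands and plans according to the overall strategy, and assign tasks and priorities to the lower-layer agents.&lt;/p&gt;
&lt;p&gt;At the lower layer, tasks are implemented concretely with finite state machines. A finite state machine lets the agent make different decisions flexibly according to its current state and the environmental conditions, adapting to the needs of the task.&lt;/p&gt;
&lt;p&gt;This layered design makes the decision system more organized and flexible. The upper-layer agent handles overall planning and task assignment, guiding military operations from a macro perspective, while the lower-layer agents make local decisions through finite state machines, so that they can react in time to different situations.&lt;/p&gt;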
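&lt;p&gt;The sketch below shows what one lower-layer task FSM might look like. It is written for illustration only and is not taken from the project. The state names, event names and transition table are hypothetical, loosely following the stages (takeoff, transit, execution, return) and the triggered events (low fuel, lost target) that this post describes for its air-intercept task.&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
```python
from enum import Enum, auto


class TaskState(Enum):
    IDLE = auto()
    TAKEOFF = auto()
    TRANSIT = auto()
    EXECUTE = auto()
    RETURN = auto()
    DONE = auto()


# Hypothetical transition table for the normal task flow: (state, event) -> next state.
TRANSITIONS = {
    (TaskState.IDLE, "assigned"): TaskState.TAKEOFF,
    (TaskState.TAKEOFF, "airborne"): TaskState.TRANSIT,
    (TaskState.TRANSIT, "arrived"): TaskState.EXECUTE,
    (TaskState.EXECUTE, "task_complete"): TaskState.RETURN,
    (TaskState.RETURN, "landed"): TaskState.DONE,
}

# Hypothetical triggered events that override the normal flow while a task is active.
INTERRUPTS = {
    "low_fuel": TaskState.RETURN,
    "target_lost": TaskState.RETURN,
}


class TaskFSM:
    """Lower-layer finite state machine for a single task (illustrative only)."""

    def __init__(self):
        self.state = TaskState.IDLE

    def on_event(self, event):
        active = self.state not in (TaskState.IDLE, TaskState.DONE)
        if active and event in INTERRUPTS:
            self.state = INTERRUPTS[event]
        else:
            # (state, event) pairs not in the table leave the state unchanged.
            self.state = TRANSITIONS.get((self.state, event), self.state)
        return self.state


fsm = TaskFSM()
for e in ["assigned", "airborne", "low_fuel", "landed"]:
    print(e, "->", fsm.on_event(e).name)
```
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;The loop moves the task through TAKEOFF and TRANSIT. The low_fuel trigger then forces RETURN, and landing ends in DONE. In the layered design, the upper-layer agent would only decide which task to start and for which units, leaving this step-by-step control to the FSM.&lt;/p&gt;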
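&lt;p&gt;Because battlefield decisions are divided into preset tasks and real-time decision tasks, and to maximize exploration efficiency without sacrificing adversarial performance, the agent was finally implemented with the following structure:&lt;/p&gt;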
    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/1991405205.html"/>
    <id>https://blog.aivgg.com/posts/1991405205.html</id>
    <published>2026-06-10T16:18:36.236Z</published>
    <updated>2026-06-10T16:24:04.385Z</updated>
    
    <content type="html"><![CDATA[{"name":"ai","services":{"ai-web":{"image":"uuv_web","ports":["36345:80"],"volumes":["D:/TianGong/ai/ai_ui/dist:/usr/share/nginx/html","./nginx-uuv.conf:/etc/nginx/conf.d/default.conf"]},"ai-mysql":{"image":"mysql:5.7.24","ports":["23306:3306"],"volumes":["./mysqld.cnf:/etc/mysql/mysql.conf.d/mysqld.cnf"],"environment":["MYSQL_ROOT_PASSWORD=123456","MYSQL_DATABASE=zcdb"],"command":["--character-set-server=utf8mb4","--collation-server=utf8mb4_general_ci","--skip-character-set-client-handshake"]},"ai-server1":{"image":"uuv_server_v2","volumes":["D:/TianGong/ai/ai_server/code/zcProject:/home/zc"],"ports":["38045:8000"],"command":["/bin/bash","-c","python /home/zc/manage.py makemigrations\npython /home/zc/manage.py migrate\npython /home/zc/manage.py runserver 0.0.0.0:8000\n"],"tty":true,"restart":"always","depends_on":["ai-mysql","ai-web"]}}}]]></content>
    
    <summary type="html">
    
      {&quot;name&quot;:&quot;ai&quot;,&quot;services&quot;:{&quot;ai-web&quot;:{&quot;image&quot;:&quot;uuv_web&quot;,&quot;ports&quot;:[&quot;36345:80&quot;],&quot;volumes&quot;:[&quot;D:/TianGong/ai/ai_ui/dist:/usr/share/nginx/html&quot;,&quot;./nginx-uuv.conf:/etc/nginx/conf.d/default.conf&quot;]},&quot;ai-mysql&quot;:{&quot;image&quot;:&quot;mysql:5.7.24&quot;,&quot;ports&quot;:[&quot;23306:3306&quot;],&quot;volumes&quot;:[&quot;./mysqld.cnf:/etc/mysql/mysql.conf.d/mysqld.cnf&quot;],&quot;environment&quot;:[&quot;MYSQL_ROOT_PASSWORD=123456&quot;,&quot;MYSQL_DATABASE=zcdb&quot;],&quot;command&quot;:[&quot;--character-set-server=utf8mb4&quot;,&quot;--collation-server=utf8mb4_general_ci&quot;,&quot;--skip-character-set-client-handshake&quot;]},&quot;ai-server1&quot;:{&quot;image&quot;:&quot;uuv_server_v2&quot;,&quot;volumes&quot;:[&quot;D:/TianGong/ai/ai_server/code/zcProject:/home/zc&quot;],&quot;ports&quot;:[&quot;38045:8000&quot;],&quot;command&quot;:[&quot;/bin/bash&quot;,&quot;-c&quot;,&quot;python /home/zc/manage.py makemigrations&#92;npython /home/zc/manage.py migrate&#92;npython /home/zc/manage.py runserver 0.0.0.0:8000&#92;n&quot;],&quot;tty&quot;:true,&quot;restart&quot;:&quot;always&quot;,&quot;depends_on&quot;:[&quot;ai-mysql&quot;,&quot;ai-web&quot;]}}}
    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/2973059647.html"/>
    <id>https://blog.aivgg.com/posts/2973059647.html</id>
    <published>2026-06-10T16:18:36.235Z</published>
    <updated>2026-06-10T16:24:04.385Z</updated>
    
    <content type="html"><![CDATA[{"name":"ray","services":{"work":{"image":"ray_server_v2","environment":["NVIDIA_VISIBLE_DEVICES=all"],"restart":"always","network_mode":"host","command":"/bin/bash -c \"/root/miniconda3/envs/train_uuv/bin/ray start --address='192.168.2.2:6379' --block\""}}}]]></content>
    
    <summary type="html">
    
      {&quot;name&quot;:&quot;ray&quot;,&quot;services&quot;:{&quot;work&quot;:{&quot;image&quot;:&quot;ray_server_v2&quot;,&quot;environment&quot;:[&quot;NVIDIA_VISIBLE_DEVICES=all&quot;],&quot;restart&quot;:&quot;always&quot;,&quot;network_mode&quot;:&quot;host&quot;,&quot;command&quot;:&quot;/bin/bash -c &#92;&quot;/root/miniconda3/envs/train_uuv/bin/ray start --address=&#39;192.168.2.2:6379&#39; --block&#92;&quot;&quot;}}}
    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/1166595107.html"/>
    <id>https://blog.aivgg.com/posts/1166595107.html</id>
    <published>2026-06-10T16:18:36.234Z</published>
    <updated>2026-06-10T16:24:04.384Z</updated>
    
    <content type="html"><![CDATA[{"name":"ray","services":{"head":{"image":"ray_server_v2","network_mode":"host","volumes":["/home/user/uuv/code/:/home"],"environment":["NVIDIA_VISIBLE_DEVICES=all"],"command":"/bin/bash -c \"/root/miniconda3/envs/train_uuv/bin/ray start --head --node-ip-address='0.0.0.0' --dashboard-host='0.0.0.0' --dashboard-port=8265 --block\"","tty":true}}}]]></content>
    
    <summary type="html">
    
      {&quot;name&quot;:&quot;ray&quot;,&quot;services&quot;:{&quot;head&quot;:{&quot;image&quot;:&quot;ray_server_v2&quot;,&quot;network_mode&quot;:&quot;host&quot;,&quot;volumes&quot;:[&quot;/home/user/uuv/code/:/home&quot;],&quot;environment&quot;:[&quot;NVIDIA_VISIBLE_DEVICES=all&quot;],&quot;command&quot;:&quot;/bin/bash -c &#92;&quot;/root/miniconda3/envs/train_uuv/bin/ray start --head --node-ip-address=&#39;0.0.0.0&#39; --dashboard-host=&#39;0.0.0.0&#39; --dashboard-port=8265 --block&#92;&quot;&quot;,&quot;tty&quot;:true}}}
    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title></title>
    <link href="https://blog.aivgg.com/posts/3606857944.html"/>
    <id>https://blog.aivgg.com/posts/3606857944.html</id>
    <published>2026-06-10T16:18:36.130Z</published>
    <updated>2026-06-10T16:24:04.372Z</updated>
    
    <content type="html"><![CDATA[<h3 id="背景"><a href="#背景" class="headerlink" title="背景"></a>背景</h3><p>在强化学习解决问题的场景中，动作是体现学习效果最直接的因素，直接影响了智能体下一步的走向和对环境状态的改变。在应用强化学习解决实际问题时，往往不同于gym库中倒立摆那样的情况，而是存在很多的约束。例如，在t时刻智能体可选的动作为1,2,3，但是在t+1时刻只能选1,2.3处于不可用的状态。在这种情况下，就需要借助掩码mask来对智能体的动作进行处理。</p><p>有人会疑问：就不能制定相应的奖励函数使得智能体学习到这种约束吗？这样做是可以的，但是付出的训练代价很大，并且极其容易导致模型发散。因此，在大多数RL落地的场景下，都会使用MASK掩码方法解决动作约束的问题。</p><h3 id="MASK的方法"><a href="#MASK的方法" class="headerlink" title="MASK的方法"></a>MASK的方法</h3><p>Mask的核心就是在输出的动作或者值函数的向量上戴个“面具”，点乘一个{0,1}或者{−∞,1}的行向量，以规范化输出。这样智能体选出的动作就可以进行简单的规范化。</p><h3 id="MASK的两个关键点"><a href="#MASK的两个关键点" class="headerlink" title="MASK的两个关键点"></a>MASK的两个关键点</h3><p>由于强化学习，尤其是深度强化学习，<strong>学的最后还是分布</strong>，因此只是单单的不让智能体选择不符合规则的动作并不能加速模型的收敛。</p><p>因此，MASK一般加在选择动作前的值函数向量或者其他数据向量上，并且会将MASK后的值传入神经网络训练。<br>两个关键点分别是：</p><blockquote><p>1-mask分布 </p><p>2-回传训练</p></blockquote><h3 id="具体做法"><a href="#具体做法" class="headerlink" title="具体做法"></a>具体做法</h3><p>以openai中MASK星际争霸智能体的动作为例：首先是环境部分self.env，使用的是为每个agent提供一个available的动作集合，可以随时调用这个方法以获取agent此时的可执行动作：</p><p><img src="https://img-blog.csdnimg.cn/890214f3c13e48b59ad314e1acb77d56.png" alt="在这里插入图片描述"></p><p>然后在agent的动作选择阶段，使用inf代替不符合要求的部分，使得softmax选择的动作合理：</p><p><img src="https://img-blog.csdnimg.cn/dc127fe28ba3486e9084f64dadcc67c4.png" alt="在这里插入图片描述"></p><p>最后在policy学习更新的部分，同样利用-9999999作为不合理动作的替换，使得反向传播的概率分布与采样一致：</p><p><img src="https://img-blog.csdnimg.cn/ad473c4ccf94402b88bc18d8757a4a7b.png" alt="在这里插入图片描述"></p><p>在星际争霸游戏中，任何时刻，整个动作空间中只有一小部分子集的动作可以执行。为了防止 AI 在某些时刻选取当前时刻无法执行的动作，需要对动作空间进行 mask。具体操作时，<strong>如果选择了当前时刻不可用的动作，就会执行 no-op（no operation，即不操作）</strong></p><h3 id="实现"><a href="#实现" class="headerlink" title="实现"></a>实现</h3><p>第一步，自定义环境：</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span 
class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> gym</span><br><span class="line"><span class="keyword">from</span> gym.spaces <span class="keyword">import</span> Box, <span class="type">Dict</span>, Discrete</span><br><span class="line"></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">MyParamActionEnv</span>(gym.Env):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, max_avail_actions, action_embedding_sz</span>):</span><br><span class="line">        <span class="variable language_">self</span>.action_space = Discrete(max_avail_actions)</span><br><span class="line">        <span class="variable language_">self</span>.observation_space = <span class="type">Dict</span>(&#123;</span><br><span class="line">            <span class="string">&quot;action_mask&quot;</span>: Box(<span class="number">0</span>, <span class="number">1</span>, shape=(max_avail_actions, )), <span class="comment"># action_mask has the same size as action_space (1 = valid, 0 = invalid)</span></span><br><span class="line">            <span class="string">&quot;avail_actions&quot;</span>: Box(-<span class="number">1</span>, <span class="number">1</span>, shape=(max_avail_actions, action_embedding_sz)),</span><br><span class="line">            <span class="string">&quot;real_obs&quot;</span>: ...,</span><br><span class="line">        &#125;)</span><br></pre></td></tr></table></figure><p>Step 2: define a custom model that applies the mask to its output logits:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span 
class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">ParametricActionsModel</span>(<span class="title class_ inherited__">TFModelV2</span>):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self,</span></span><br><span class="line"><span class="params">                 obs_space,</span></span><br><span class="line"><span class="params">                 action_space,</span></span><br><span class="line"><span class="params">                 num_outputs,</span></span><br><span class="line"><span class="params">                 model_config,</span></span><br><span class="line"><span class="params">                 name,</span></span><br><span class="line"><span class="params">                 true_obs_shape=(<span class="params"><span class="number">4</span>,</span>),</span></span><br><span class="line"><span class="params">                 action_embed_size=<span class="number">2</span></span>):</span><br><span class="line">        <span class="built_in">super</span>(ParametricActionsModel, <span class="variable language_">self</span>).__init__(</span><br><span class="line">            obs_space, action_space, num_outputs, model_config, name)</span><br><span class="line">        <span class="variable language_">self</span>.action_embed_model = FullyConnectedNetwork(...)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title 
function_">forward</span>(<span class="params">self, input_dict, state, seq_lens</span>):</span><br><span class="line">        <span class="comment"># Extract the available actions tensor from the observation.</span></span><br><span class="line">        avail_actions = input_dict[<span class="string">&quot;obs&quot;</span>][<span class="string">&quot;avail_actions&quot;</span>]</span><br><span class="line">        action_mask = input_dict[<span class="string">&quot;obs&quot;</span>][<span class="string">&quot;action_mask&quot;</span>]</span><br><span class="line"></span><br><span class="line">        <span class="comment"># Compute the predicted action embedding</span></span><br><span class="line">        action_embed, _ = <span class="variable language_">self</span>.action_embed_model(&#123;</span><br><span class="line">            <span class="string">&quot;obs&quot;</span>: input_dict[<span class="string">&quot;obs&quot;</span>][<span class="string">&quot;real_obs&quot;</span>]</span><br><span class="line">        &#125;)</span><br><span class="line"></span><br><span class="line">        <span class="comment"># Expand the model output to [BATCH, 1, EMBED_SIZE]. 
Note that the</span></span><br><span class="line">        <span class="comment"># avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].</span></span><br><span class="line">        intent_vector = tf.expand_dims(action_embed, <span class="number">1</span>)</span><br><span class="line"></span><br><span class="line">        <span class="comment"># Batch dot product =&gt; shape of logits is [BATCH, MAX_ACTIONS].</span></span><br><span class="line">        action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=<span class="number">2</span>)</span><br><span class="line"></span><br><span class="line">        <span class="comment"># Mask out invalid actions: log(1) = 0 keeps valid logits, log(0) = -inf is clipped to tf.float32.min for stability</span></span><br><span class="line">        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.<span class="built_in">min</span>)</span><br><span class="line">        <span class="keyword">return</span> action_logits + inf_mask, state</span><br></pre></td></tr></table></figure><p>A complete reference example, adapted from RLlib's action-masking example:</p><p>Step 1: define a custom environment:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span 
class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">ActionMaskEnv</span>(<span class="title class_ inherited__">RandomEnv</span>):</span><br><span class="line">    <span class="string">&quot;&quot;&quot;A randomly acting environment that publishes an action-mask each step.&quot;&quot;&quot;</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, config</span>):</span><br><span class="line">        <span class="built_in">super</span>().__init__(config)</span><br><span class="line">        <span class="comment"># Masking only works for Discrete actions.</span></span><br><span class="line">        <span class="keyword">assert</span> <span class="built_in">isinstance</span>(<span class="variable language_">self</span>.action_space, Discrete)</span><br><span class="line">        <span class="comment"># Add action_mask to observations.</span></span><br><span class="line">        <span class="variable language_">self</span>.observation_space = <span class="type">Dict</span>(</span><br><span class="line">            &#123;</span><br><span class="line">                <span class="string">&quot;action_mask&quot;</span>: Box(<span class="number">0.0</span>, <span class="number">1.0</span>, shape=(<span class="variable language_">self</span>.action_space.n,)),</span><br><span class="line">                <span class="string">&quot;observations&quot;</span>: <span class="variable language_">self</span>.observation_space,</span><br><span class="line">            &#125;</span><br><span class="line">        )</span><br><span class="line">        <span class="variable 
language_">self</span>.valid_actions = <span class="literal">None</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">reset</span>(<span class="params">self, *, seed=<span class="literal">None</span>, options=<span class="literal">None</span></span>):</span><br><span class="line">        obs, info = <span class="built_in">super</span>().reset()</span><br><span class="line">        <span class="variable language_">self</span>._fix_action_mask(obs)</span><br><span class="line">        <span class="keyword">return</span> obs, info</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">step</span>(<span class="params">self, action</span>):</span><br><span class="line">        <span class="comment"># Check whether action is valid.</span></span><br><span class="line">        <span class="keyword">if</span> <span class="keyword">not</span> <span class="variable language_">self</span>.valid_actions[action]:</span><br><span class="line">            <span class="keyword">raise</span> ValueError(</span><br><span class="line">                <span class="string">f&quot;Invalid action sent to env! 
&quot;</span> <span class="string">f&quot;valid_actions=<span class="subst">&#123;self.valid_actions&#125;</span>&quot;</span></span><br><span class="line">            )</span><br><span class="line">        obs, rew, terminated, truncated, info = <span class="built_in">super</span>().step(action)</span><br><span class="line">        <span class="variable language_">self</span>._fix_action_mask(obs)</span><br><span class="line">        <span class="keyword">return</span> obs, rew, terminated, truncated, info</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">_fix_action_mask</span>(<span class="params">self, obs</span>):</span><br><span class="line">        <span class="comment"># Fix action-mask: everything greater than 0.5 becomes 1.0, everything else 0.0.</span></span><br><span class="line">        <span class="variable language_">self</span>.valid_actions = np.<span class="built_in">round</span>(obs[<span class="string">&quot;action_mask&quot;</span>])</span><br><span class="line">        obs[<span class="string">&quot;action_mask&quot;</span>] = <span class="variable language_">self</span>.valid_actions</span><br></pre></td></tr></table></figure><p>Step 2: define a custom model (PyTorch) that masks the logits:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span 
class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">TorchActionMaskModel</span>(TorchModelV2, nn.Module):</span><br><span class="line">    <span class="string">&quot;&quot;&quot;PyTorch version of above ActionMaskingModel.&quot;&quot;&quot;</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params"></span></span><br><span class="line"><span class="params">        self,</span></span><br><span class="line"><span class="params">        obs_space,</span></span><br><span class="line"><span class="params">        action_space,</span></span><br><span class="line"><span class="params">        num_outputs,</span></span><br><span class="line"><span class="params">        
model_config,</span></span><br><span class="line"><span class="params">        name,</span></span><br><span class="line"><span class="params">        **kwargs,</span></span><br><span class="line"><span class="params">    </span>):</span><br><span class="line">        orig_space = <span class="built_in">getattr</span>(obs_space, <span class="string">&quot;original_space&quot;</span>, obs_space)</span><br><span class="line">        <span class="keyword">assert</span> (</span><br><span class="line">            <span class="built_in">isinstance</span>(orig_space, <span class="type">Dict</span>)</span><br><span class="line">            <span class="keyword">and</span> <span class="string">&quot;action_mask&quot;</span> <span class="keyword">in</span> orig_space.spaces</span><br><span class="line">            <span class="keyword">and</span> <span class="string">&quot;observations&quot;</span> <span class="keyword">in</span> orig_space.spaces</span><br><span class="line">        )</span><br><span class="line"></span><br><span class="line">        TorchModelV2.__init__(</span><br><span class="line">            <span class="variable language_">self</span>, obs_space, action_space, num_outputs, model_config, name, **kwargs</span><br><span class="line">        )</span><br><span class="line">        nn.Module.__init__(<span class="variable language_">self</span>)</span><br><span class="line"></span><br><span class="line">        <span class="variable language_">self</span>.internal_model = TorchFC(</span><br><span class="line">            orig_space[<span class="string">&quot;observations&quot;</span>],</span><br><span class="line">            action_space,</span><br><span class="line">            num_outputs,</span><br><span class="line">            model_config,</span><br><span class="line">            name + <span class="string">&quot;_internal&quot;</span>,</span><br><span class="line">        )</span><br><span class="line"></span><br><span class="line">        <span 
class="comment"># disable action masking --&gt; will likely lead to invalid actions</span></span><br><span class="line">        <span class="variable language_">self</span>.no_masking = <span class="literal">False</span></span><br><span class="line">        <span class="keyword">if</span> <span class="string">&quot;no_masking&quot;</span> <span class="keyword">in</span> model_config[<span class="string">&quot;custom_model_config&quot;</span>]:</span><br><span class="line">            <span class="variable language_">self</span>.no_masking = model_config[<span class="string">&quot;custom_model_config&quot;</span>][<span class="string">&quot;no_masking&quot;</span>]</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, input_dict, state, seq_lens</span>):</span><br><span class="line">        <span class="comment"># Extract the available actions tensor from the observation.</span></span><br><span class="line">        action_mask = input_dict[<span class="string">&quot;obs&quot;</span>][<span class="string">&quot;action_mask&quot;</span>]</span><br><span class="line"></span><br><span class="line">        <span class="comment"># Compute the unmasked logits.</span></span><br><span class="line">        logits, _ = <span class="variable language_">self</span>.internal_model(&#123;<span class="string">&quot;obs&quot;</span>: input_dict[<span class="string">&quot;obs&quot;</span>][<span class="string">&quot;observations&quot;</span>]&#125;)</span><br><span class="line"></span><br><span class="line">        <span class="comment"># If action masking is disabled, directly return unmasked logits</span></span><br><span class="line">        <span class="keyword">if</span> <span class="variable language_">self</span>.no_masking:</span><br><span class="line">            <span class="keyword">return</span> logits, state</span><br><span class="line"></span><br><span 
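class="line">        <span class="comment"># Convert action_mask into a [0.0 || -inf]-type mask.</span></span><br><span class="line">        inf_mask = torch.clamp(torch.log(action_mask), <span class="built_in">min</span>=FLOAT_MIN)</span><br><span class="line">        masked_logits = logits + inf_mask</span><br><span class="line"></span><br><span class="line">        <span class="comment"># Return masked logits.</span></span><br><span class="line">        <span class="keyword">return</span> masked_logits, state</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">value_function</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="keyword">return</span> <span class="variable language_">self</span>.internal_model.value_function()</span><br></pre></td></tr></table></figure><p><code>torch.clamp</code> clamps every element of the <code>input</code> tensor into the interval [min, max]. Here only <code>min=FLOAT_MIN</code> is given, so values are bounded from below only.</p><p>For a valid action the mask is 1 and <code>log(1) = 0</code>, so its logit is unchanged. For an invalid action the mask is 0 and <code>log(0) = -inf</code>, which the clamp replaces with <code>FLOAT_MIN</code>, a very large negative but finite number. In other words, <strong>invalid actions get a logit of (effectively) -inf, so the softmax gives them (effectively) zero probability and only valid actions are chosen</strong>.</p><p><img src="https://img-blog.csdnimg.cn/6557a481bce4496abdad60b504e93c54.png" alt="Action masking illustration"></p><p><code>forward</code> receives the inputs for a whole batch. For <code>Discrete(100)</code>, it returns one logit per action (0-99) for every sample in the batch. The action distribution then turns these logits into selection probabilities.</p>]]></content>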
    
    <summary type="html">
    
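      &lt;h3 id=&quot;背景&quot;&gt;&lt;a href=&quot;#背景&quot; class=&quot;headerlink&quot; title=&quot;背景&quot;&gt;&lt;/a&gt;Background&lt;/h3&gt;&lt;p&gt;When reinforcement learning is applied to a problem, actions are the most direct expression of what the agent has learned: they determine the agent's next move and how it changes the state of the environment. Real-world problems, unlike the inverted pendulum (CartPole) environment in gym, usually come with many constraints. For example, at time t the agent may choose among actions 1, 2 and 3, but at time t+1 only actions 1 and 2 are allowed and action 3 is unavailable. In such cases, a mask is used to constrain the agent's actions.&lt;/p&gt;
&lt;p&gt;One might ask: why not design a reward function that teaches the agent these constraints? That is possible, but training becomes much more expensive and the model diverges very easily. This is why most practical RL deployments handle action constraints with masking.&lt;/p&gt;
&lt;h3 id=&quot;MASK的方法&quot;&gt;&lt;a href=&quot;#MASK的方法&quot; class=&quot;headerlink&quot; title=&quot;MASK的方法&quot;&gt;&lt;/a&gt;The Masking Method&lt;/h3&gt;&lt;p&gt;The core idea of masking is to put a &quot;mask&quot; over the output vector of action probabilities or values so that invalid actions cannot be chosen. There are two ways to do this. You can multiply the probabilities element-wise by a {0, 1} vector, or you can add a {−∞, 0} vector to the logits or Q-values before the softmax or argmax. Either way, the action the agent selects is always a valid one.&lt;/p&gt;
&lt;h3 id=&quot;MASK的两个关键点&quot;&gt;&lt;a href=&quot;#MASK的两个关键点&quot; class=&quot;headerlink&quot; title=&quot;MASK的两个关键点&quot;&gt;&lt;/a&gt;Two Key Points of Masking&lt;/h3&gt;&lt;p&gt;Reinforcement learning, and deep reinforcement learning in particular, &lt;strong&gt;ultimately learns a distribution&lt;/strong&gt;, so merely stopping the agent from executing invalid actions does not speed up convergence.&lt;/p&gt;
&lt;p&gt;Therefore, the mask is usually applied to the value-function vector (or another output vector) before action selection, and the masked values are also what the neural network is trained on.&lt;br&gt;The two key points are:&lt;/p&gt;
&lt;blockquote&gt;
&lt;p&gt;1 - mask the distribution&lt;/p&gt;
&lt;p&gt;2 - backpropagate through the masked output during training&lt;/p&gt;
&lt;/blockquote&gt;
&lt;h3 id=&quot;具体做法&quot;&gt;&lt;a href=&quot;#具体做法&quot; class=&quot;headerlink&quot; title=&quot;具体做法&quot;&gt;&lt;/a&gt;Concrete Approach&lt;/h3&gt;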
    
    </summary>
    
    
    
  </entry>
  
  <entry>
    <title>Full-Web VIP Video Parsing APIs (Open Source, Stable, CDN-Accelerated): Supports Tencent Video, iQiyi, Youku and Dozens of Other Platforms</title>
    <link href="https://blog.aivgg.com/posts/1257523462.html"/>
    <id>https://blog.aivgg.com/posts/1257523462.html</id>
    <published>2026-06-10T16:18:36.120Z</published>
    <updated>2023-09-14T03:24:25.000Z</updated>
    
    <content type="html"><![CDATA[<p><strong>Full-web parsing: supported sites</strong></p><ul><li>iQiyi, Tencent Video, Youku, Tudou, Mango TV, LeTV, Sohu, PPTV, Wasu TV, Funshion, Migu, Bilibili, ACfun, Baofeng, CCTV, CNTV, Fantexi, 9i Square Dance, Sohu self-media, M1905 Video, Kankan Video, 27pan, Huya Live, Quanmin Live, Zhanqi Live, Renren Video, Baomihua, Toutiao, Tianyi Video, Tangdou Video, Longzhu Video, Kuaishou<br>Yizhibo, Sina Video, 360 short video, Panda TV, Douyu TV, Huajiao Live, NetEase Open Courses, Yinyuetai, Miaopai, Meipai, Aipai, Phoenix Video, Pear Video, Weiluke, Renmin Micro Video, 17173 Video, Youmi Video, m3u8 and mp4 videos, Weibo Video, YY Video, private cloud resources</li></ul><p><strong><a href="https://cdn.yangju.vip/k/?url=">https://cdn.yangju.vip/k/?url=</a></strong> followed by the URL of the video you want to play.</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line">https://cdn.yangju.vip/k/?url=</span><br><span class="line"></span><br><span class="line">https://jx.lache.me/cc/?url=</span><br><span class="line"></span><br><span class="line">https://api.653520.top/vip/?url=</span><br><span class="line"></span><br><span class="line">https://jx.ab33.top/vip/?url=</span><br><span class="line"></span><br><span class="line">https://vip.mpos.ren/v/?url=</span><br><span class="line"></span><br><span class="line">https://jx.000180.top/jx/?url=</span><br><span class="line"></span><br><span class="line">https://jx.km58.top/jx/?url=</span><br><span class="line"></span><br><span 
class="line">https://api.smq1.com/?url=</span><br><span class="line"></span><br><span class="line">https://jx.hezeshi.net/ce/jlexi.php?url=</span><br><span class="line"></span><br><span class="line">https://www.kkflv.com/?url=</span><br><span class="line"></span><br><span class="line">https://jx.618g.com/?url=</span><br><span class="line"></span><br></pre></td></tr></table></figure><p>Permanent and, most importantly, stable, with CDN acceleration. The parsing endpoints work in URL mode: append the video URL after <code>?url=</code>.</p>]]></content>
    
    <summary type="html">
    
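      &lt;p&gt;&lt;strong&gt;Full-web parsing: supported sites&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;iQiyi, Tencent Video, Youku, Tudou, Mango TV, LeTV, Sohu, PPTV, Wasu TV, Funshion, Migu, Bilibili, ACfun, Baofeng, CCTV, CNTV, Fantexi, 9i Square Dance, Sohu self-media, M1905 Video, Kankan Video, 27pan, Huya Live, Quanmin Live, Zhanqi Live, Renren Video, Baomihua, Toutiao, Tianyi Video, Tangdou Video, Longzhu Video, Kuaishou&lt;br&gt;Yizhibo, Sina Video, 360 short video, Panda TV, Douyu TV, Huajiao Live, NetEase Open Courses, Yinyuetai, Miaopai, Meipai, Aipai, Phoenix Video, Pear Video, Weiluke, Renmin Micro Video, 17173 Video, Youmi Video, m3u8 and mp4 videos, Weibo Video, YY Video, private cloud resources&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;strong&gt;&lt;a href=&quot;https://cdn.yangju.vip/k/?url=&quot;&gt;https://cdn.yangju.vip/k/?url=&lt;/a&gt;&lt;/strong&gt; followed by the URL of the video you want to play.&lt;/p&gt;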
&lt;figure class=&quot;highlight plaintext&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;8&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;9&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;10&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;11&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;12&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;13&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;14&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;15&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;16&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;17&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;18&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;19&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;20&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;21&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;22&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;https://cdn.yangju.vip/k/?url=&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;https://jx.lache.me/cc/?url=&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;https://api.653520.top/vip/?url=&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span 
class=&quot;line&quot;&gt;https://jx.ab33.top/vip/?url=&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;https://vip.mpos.ren/v/?url=&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;https://jx.000180.top/jx/?url=&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;https://jx.km58.top/jx/?url=&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;https://api.smq1.com/?url=&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;https://jx.hezeshi.net/ce/jlexi.php?url=&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;https://www.kkflv.com/?url=&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;https://jx.618g.com/?url=&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;

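&lt;p&gt;Permanent and, most importantly, stable, with CDN acceleration. The parsing endpoints work in URL mode: append the video URL after &lt;code&gt;?url=&lt;/code&gt;.&lt;/p&gt;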

    
    </summary>
    
    
      <category term="电视剧" scheme="https://blog.aivgg.com/categories/%E7%94%B5%E8%A7%86%E5%89%A7/"/>
    
    
      <category term="电视剧" scheme="https://blog.aivgg.com/tags/%E7%94%B5%E8%A7%86%E5%89%A7/"/>
    
  </entry>
  
  <entry>
    <title>Redis Streams as a Message Queue</title>
    <link href="https://blog.aivgg.com/posts/441991174.html"/>
    <id>https://blog.aivgg.com/posts/441991174.html</id>
    <published>2026-06-10T16:18:36.117Z</published>
    <updated>2026-06-10T16:24:04.371Z</updated>
    
    <content type="html"><![CDATA[<p>Redis 5.0 introduced the Stream data type, which works well as a message queue. It persists messages and, through consumer groups, remembers how far each consumer has read, so no messages are lost even if a client disconnects and reconnects.</p><p>Use XADD to append a message to a stream; if the stream does not exist yet, it is created. XADD syntax:</p><figure class="highlight plaintext"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">XADD key ID field value [field value ...]</span><br><span class="line">key: name of the stream; created if it does not exist</span><br><span class="line">ID: message ID; * lets Redis generate it. A custom ID is allowed, but it must be greater than the last ID in the stream.</span><br><span class="line">field value: the entry's contents, one or more field-value pairs</span><br></pre></td></tr></table></figure><figure class="highlight sh"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">redis&gt; XADD mystream * name Sara surname OConnor</span><br><span class="line"><span class="string">&quot;1601372323627-0&quot;</span></span><br><span class="line">redis&gt; XADD mystream * field1 value1 field2 value2 field3 value3</span><br><span class="line"><span class="string">&quot;1601372323627-1&quot;</span></span><br><span class="line">redis&gt; XLEN mystream</span><br><span class="line">(<span class="built_in">integer</span>) 2</span><br></pre></td></tr></table></figure><p>The ID argument of XADD (not the key) is the message ID. It is usually <code>*</code>, meaning Redis generates it, but it can also be set explicitly. An explicit ID must be greater than the last ID already in the stream. On the <code>mystream</code> above, whose last ID is 1601372323627-1, the following commands would be rejected, so run them on a new, empty stream:</p><figure class="highlight sh"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">XADD mystream 10000000 name Anna</span><br><span class="line">XADD mystream 10000001 name Bert</span><br><span class="line">XADD mystream 10000002 name Cathy</span><br></pre></td></tr></table></figure><p>Use the MAXLEN option to cap the number of entries in a stream. For example, <code>XADD mystream MAXLEN 1000 * name Sara</code> trims the stream to at most 1000 entries after adding the new one, and <code>MAXLEN ~ 1000</code> allows cheaper, approximate trimming.</p><p><img 
src="https://img-blog.csdnimg.cn/910705baf2a14ff98fe92c536fe85a80.png" alt="Screenshot: MAXLEN example"></p><p>To read from a stream, for example to read up to 100 entries starting from the beginning of the stream (ID <code>0</code>):</p><figure class="highlight sh"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">XREAD COUNT 100 STREAMS mystream 0</span><br></pre></td></tr></table></figure>]]></content>
    
    <summary type="html">
    
      &lt;p&gt;Redis 5.0 introduced the Stream data type, which works well as a message queue. It persists messages and, through consumer groups, remembers how far each consumer has read, so no messages are lost even if a client disconnects and reconnects.&lt;/p&gt;
&lt;p&gt;Use XADD to append a message to a stream; if the stream does not exist yet, it is created. XADD syntax:&lt;/p&gt;
&lt;figure class=&quot;highlight plaintext&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;XADD key ID field value [field value ...]&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;key: name of the stream; created if it does not exist&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;ID: message ID; * lets Redis generate it. A custom ID is allowed, but it must be greater than the last ID in the stream.&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;field value: the entry's contents, one or more field-value pairs&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;

&lt;figure class=&quot;highlight sh&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;redis&amp;gt; XADD mystream * name Sara surname OConnor&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;string&quot;&gt;&amp;quot;1601372323627-0&amp;quot;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;redis&amp;gt; XADD mystream * field1 value1 field2 value2 field3 value3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;string&quot;&gt;&amp;quot;1601372323627-1&amp;quot;&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;redis&amp;gt; XLEN mystream&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;(&lt;span class=&quot;built_in&quot;&gt;integer&lt;/span&gt;) 2&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;
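&lt;p&gt;In XADD, the ID argument (not the key, which is the stream name) is the entry ID. It is usually * so that Redis generates it automatically, but it can also be set manually. A manual ID must be greater than the last ID already in the stream. The commands below therefore only succeed on a stream whose last ID is smaller, such as a new stream. On the mystream above, whose last ID is 1601372323627-1, they would be rejected.&lt;/p&gt;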
&lt;figure class=&quot;highlight sh&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;XADD mystream 10000000 name Anna  &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;XADD mystream 10000001 name Bert  &lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;XADD mystream 10000002 name Cathy&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;
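&lt;p&gt;To make the ID rules concrete, here is a minimal in-memory sketch in plain Python. It assumes only the behavior described above:&lt;/p&gt;
&lt;ul&gt;&lt;li&gt;An ID has the form ms-seq.&lt;/li&gt;&lt;li&gt;* asks the server to generate an ID. Within the same millisecond it bumps the sequence, as in 1601372323627-0 and 1601372323627-1 above.&lt;/li&gt;&lt;li&gt;Every new ID must be strictly greater than the last ID in the stream.&lt;/li&gt;&lt;/ul&gt;
&lt;p&gt;It also assumes a bare number such as 10000000 is read as 10000000-0. StreamIdModel and xadd are hypothetical names; this is neither redis-py nor the Redis implementation.&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
```python
import time
from typing import Optional


class StreamIdModel:
    """Toy model of stream entry IDs (illustration only, not Redis)."""

    def __init__(self) -> None:
        self.last = (0, 0)   # (milliseconds, sequence) of the newest entry; 0-0 is never valid
        self.entries = []    # ((ms, seq), fields) pairs in insertion order

    @staticmethod
    def parse(raw: str) -> tuple:
        # "1601372323627-1" -> (1601372323627, 1); a bare "10000000" is assumed to mean 10000000-0
        ms, _, seq = raw.partition("-")
        return int(ms), int(seq or 0)

    def xadd(self, raw_id: str, fields: dict, now_ms: Optional[int] = None) -> str:
        if raw_id == "*":
            ms = int(time.time() * 1000) if now_ms is None else now_ms
            if ms > self.last[0]:
                new_id = (ms, 0)                           # new millisecond: sequence restarts at 0
            else:
                new_id = (self.last[0], self.last[1] + 1)  # same (or earlier) ms: bump the sequence
        else:
            new_id = self.parse(raw_id)
            if not new_id > self.last:
                raise ValueError("ID must be greater than the last ID in the stream")
        self.last = new_id
        self.entries.append((new_id, fields))
        return f"{new_id[0]}-{new_id[1]}"


stream = StreamIdModel()
print(stream.xadd("*", {"name": "Sara", "surname": "OConnor"}, now_ms=1601372323627))  # 1601372323627-0
print(stream.xadd("*", {"field1": "value1"}, now_ms=1601372323627))                      # 1601372323627-1

fresh = StreamIdModel()
print(fresh.xadd("10000000", {"name": "Anna"}))  # 10000000-0
```
&lt;/code&gt;&lt;/pre&gt;
&lt;p&gt;Calling stream.xadd with the ID 10000000 on the first model raises ValueError, because 10000000-0 is smaller than 1601372323627-1. The fresh model accepts it.&lt;/p&gt;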
&lt;p&gt;Use the MAXLEN option of XADD to cap the number of entries kept in a stream. Older entries are trimmed as new ones are added.&lt;/p&gt;
&lt;p&gt;&lt;img src=&quot;https://img-blog.csdnimg.cn/910705baf2a14ff98fe92c536fe85a80.png&quot; alt=&quot;MAXLEN example&quot;&gt;&lt;/p&gt;
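&lt;p&gt;Here is a rough sketch of what exact trimming does, assuming a command such as XADD mystream MAXLEN 3 * field value, where only the newest 3 entries are kept after each append. A collections.deque with maxlen behaves the same way, so it is used here as a stand-in. add_entry is a hypothetical helper. The approximate MAXLEN ~ form, which lets Redis trim less eagerly, is not modeled.&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
```python
from collections import deque

# Stand-in for a stream trimmed with XADD mystream MAXLEN 3 * ... (exact trimming only)
capped = deque(maxlen=3)


def add_entry(stream: deque, entry_id: str, fields: dict) -> None:
    # When the deque is full, append() evicts the oldest item first,
    # which matches exact MAXLEN trimming: only the newest entries survive.
    stream.append((entry_id, fields))


for n in range(1, 6):
    add_entry(capped, f"{n}-0", {"n": str(n)})

kept_ids = [entry_id for entry_id, _ in capped]
print(kept_ids)  # ['3-0', '4-0', '5-0']
```
&lt;/code&gt;&lt;/pre&gt;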
&lt;p&gt;To read from a stream, for example up to 100 entries starting from the beginning of the stream:&lt;/p&gt;
&lt;figure class=&quot;highlight sh&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;XREAD COUNT 100 STREAMS mystream 0 &lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;
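&lt;p&gt;The semantics of this call can be sketched in plain Python, assuming entries are kept in ID order:&lt;/p&gt;
&lt;ul&gt;&lt;li&gt;XREAD returns at most COUNT entries whose ID is strictly greater than the ID given after the stream name.&lt;/li&gt;&lt;li&gt;0 therefore means from the beginning.&lt;/li&gt;&lt;li&gt;Passing the last ID you received fetches the next page.&lt;/li&gt;&lt;/ul&gt;
&lt;p&gt;parse_id and xread are hypothetical helpers. Blocking reads (BLOCK) and the special $ ID are not modeled.&lt;/p&gt;
&lt;pre&gt;&lt;code class=&quot;language-python&quot;&gt;
```python
def parse_id(raw: str) -> tuple:
    # "1601372323627-1" -> (1601372323627, 1); a bare "0" is treated as (0, 0)
    ms, _, seq = raw.partition("-")
    return int(ms), int(seq or 0)


def xread(entries: list, last_id: str, count: int) -> list:
    """Return at most count entries whose ID is strictly greater than last_id."""
    start = parse_id(last_id)
    newer = [entry for entry in entries if parse_id(entry[0]) > start]
    return newer[:count]


# 250 hypothetical entries with IDs 1-0 .. 250-0, already in ID order
entries = [(f"{n}-0", {"n": str(n)}) for n in range(1, 251)]

first_page = xread(entries, "0", 100)               # like XREAD COUNT 100 STREAMS mystream 0
next_page = xread(entries, first_page[-1][0], 100)  # continue after the last ID received
print(len(first_page), first_page[0][0], next_page[0][0])  # 100 1-0 101-0
```
&lt;/code&gt;&lt;/pre&gt;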
    
    </summary>
    
    
      <category term="大数据技术" scheme="https://blog.aivgg.com/categories/%E5%A4%A7%E6%95%B0%E6%8D%AE%E6%8A%80%E6%9C%AF/"/>
    
    
      <category term="数据库技术" scheme="https://blog.aivgg.com/tags/%E6%95%B0%E6%8D%AE%E5%BA%93%E6%8A%80%E6%9C%AF/"/>
    
      <category term="Redis" scheme="https://blog.aivgg.com/tags/Redis/"/>
    
  </entry>
  
  <entry>
    <title>Ray global variable issue</title>
    <link href="https://blog.aivgg.com/posts/1796588721.html"/>
    <id>https://blog.aivgg.com/posts/1796588721.html</id>
    <published>2023-07-15T09:34:46.000Z</published>
    <updated>2026-06-10T16:24:04.372Z</updated>
    
    <content type="html"><![CDATA[<p>Ray remote functions should be treated as functional and side-effect free. Restricting ourselves to remote functions limits us to distributed functional programming, which is fine for many use cases but quite restrictive in practice.<br>Ray extends the dataflow model with actors. An actor is essentially a stateful worker (or service).</p><p>Suppose several tasks need to call methods on the same actor. For example, we might have an actor that records execution information from many tasks. We can do this by passing the actor handle as an argument to the relevant tasks.</p>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> ray</span><br><span class="line"></span><br><span class="line">ray.init()</span><br><span class="line"></span><br><span class="line"><span class="meta">@ray.remote</span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">Actor</span>(<span class="title class_ inherited__">object</span>):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">method</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="keyword">pass</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># Create the actor</span></span><br><span class="line">actor = Actor.remote()</span><br><span class="line"></span><br><span class="line"><span class="meta">@ray.remote</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">f</span>(<span class="params">actor</span>):</span><br><span class="line">    <span class="comment"># Call the actor method; this returns an ObjectRef immediately</span></span><br><span class="line">    x_id = actor.method.remote()</span><br><span class="line">    <span class="comment"># ray.get blocks until the result is ready</span></span><br><span class="line">    <span class="keyword">return</span> ray.get(x_id)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Three tasks call the same actor; ray.get waits for all of them</span></span><br><span class="line">ray.get([f.remote(actor), f.remote(actor), f.remote(actor)])</span><br></pre></td></tr></table></figure>
<p>See the official documentation:<br><a href="https://docs.ray.io/en/latest/ray-core/patterns/global-variables.html#anti-pattern-using-global-variables-to-share-state-between-tasks-and-actors">https://docs.ray.io/en/latest/ray-core/patterns/global-variables.html#anti-pattern-using-global-variables-to-share-state-between-tasks-and-actors</a></p><p>Sharing state through global variables is an anti-pattern: do not use global variables to share state with tasks and actors. Instead, encapsulate the global variables in an actor and pass the actor handle to other tasks and actors.</p><p>The Ray driver, tasks, and actors run in different processes, so they do not share the same address space. This means that if you modify a global variable in one process, the change is not reflected in the other processes.</p><p>The solution is to hold the global state in the instance variables of an actor and pass the actor handle to wherever the state needs to be read or modified.</p><p>Working example:</p>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> ray</span><br><span class="line"></span><br><span class="line">ray.init()</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="meta">@ray.remote</span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">GlobalVarActor</span>:</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="variable language_">self</span>.global_var = <span class="number">3</span></span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">set_global_var</span>(<span class="params">self, var</span>):</span><br><span class="line">        <span class="variable language_">self</span>.global_var = var</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">get_global_var</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="keyword">return</span> <span class="variable language_">self</span>.global_var</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="meta">@ray.remote</span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">Actor</span>:</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, global_var_actor</span>):</span><br><span class="line">        <span class="variable language_">self</span>.global_var_actor = global_var_actor</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">f</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="keyword">return</span> ray.get(<span class="variable language_">self</span>.global_var_actor.get_global_var.remote()) + <span class="number">3</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">global_var_actor = GlobalVarActor.remote()</span><br><span class="line">actor = Actor.remote(global_var_actor)</span><br><span class="line">ray.get(global_var_actor.set_global_var.remote(<span class="number">4</span>))</span><br><span class="line"><span class="comment"># This returns 7 correctly.</span></span><br><span class="line"><span class="keyword">assert</span> ray.get(actor.f.remote()) == <span class="number">7</span></span><br></pre></td></tr></table></figure>
<p>Failing example:</p>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> ray</span><br><span class="line"></span><br><span class="line">ray.init()</span><br><span class="line"></span><br><span class="line">global_var = <span class="number">3</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="meta">@ray.remote</span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">Actor</span>:</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">f</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="keyword">return</span> global_var + <span class="number">3</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">actor = Actor.remote()</span><br><span class="line">global_var = <span class="number">4</span></span><br><span class="line"><span class="comment"># This returns 6, not 7. It is because the value change of global_var</span></span><br><span class="line"><span class="comment"># inside a driver is not reflected to the actor</span></span><br><span class="line"><span class="comment"># because they are running in different processes.</span></span><br><span class="line"><span class="keyword">assert</span> ray.get(actor.f.remote()) == <span class="number">6</span></span><br><span class="line"></span><br></pre></td></tr></table></figure>
<p>The call returns 6 because actor methods run on a stateful worker process. Instantiating an actor creates a brand-new worker, and all of that actor's methods execute on that worker. When <code>actor = Actor.remote()</code> ran, <code>global_var</code> in that worker was 3, so changing it in the driver afterwards has no effect on the actor.</p><p>The fix is to pass the actor handle as an argument to the relevant tasks; this is how global data is shared.</p><p>To keep an actor's state consistent, method calls on the same actor execute serially by default, while different actors execute their methods in parallel.</p>]]></content>
    
    <summary type="html">
    
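      &lt;p&gt;Ray remote functions should be treated as functional and side-effect free. Restricting ourselves to remote functions limits us to distributed functional programming, which is fine for many use cases but quite restrictive in practice.&lt;br&gt;Ray extends the dataflow model with actors. An actor is essentially a stateful worker (or service).&lt;/p&gt;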
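&lt;p&gt;Suppose several tasks need to call methods on the same actor. For example, we might have an actor that records execution information from many tasks. We can do this by passing the actor handle as an argument to the relevant tasks.&lt;/p&gt;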
&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;8&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;9&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;10&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;11&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;12&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;13&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;14&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;15&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;16&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;17&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;18&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;19&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;20&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;21&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; ray&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;ray.init()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;meta&quot;&gt;@ray.remote&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;title class_&quot;&gt;Actor&lt;/span&gt;(&lt;span class=&quot;title class_ inherited__&quot;&gt;object&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;method&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;pass&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;comment&quot;&gt;# Create the actor&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;actor = Actor.remote()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;meta&quot;&gt;@ray.remote&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;f&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;actor&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;comment&quot;&gt;# Call the actor method; this returns an ObjectRef immediately&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    x_id = actor.method.remote()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;comment&quot;&gt;# ray.get blocks until the result is ready&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; ray.get(x_id)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;comment&quot;&gt;# Three tasks call the same actor; ray.get waits for all of them&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;ray.get([f.remote(actor), f.remote(actor), f.remote(actor)])&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;
&lt;p&gt;See the official documentation:&lt;br&gt;&lt;a href=&quot;https://docs.ray.io/en/latest/ray-core/patterns/global-variables.html#anti-pattern-using-global-variables-to-share-state-between-tasks-and-actors&quot;&gt;https://docs.ray.io/en/latest/ray-core/patterns/global-variables.html#anti-pattern-using-global-variables-to-share-state-between-tasks-and-actors&lt;/a&gt;&lt;/p&gt;
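&lt;p&gt;Sharing state through global variables is an anti-pattern: do not use global variables to share state with tasks and actors. Instead, encapsulate the global variables in an actor and pass the actor handle to other tasks and actors.&lt;/p&gt;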
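&lt;p&gt;The Ray driver, tasks, and actors run in different processes, so they do not share the same address space. This means that if you modify a global variable in one process, the change is not reflected in the other processes.&lt;/p&gt;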
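&lt;p&gt;The solution is to hold the global state in the instance variables of an actor and pass the actor handle to wherever the state needs to be read or modified.&lt;/p&gt;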
&lt;p&gt;Working example:&lt;/p&gt;
&lt;figure class=&quot;highlight python&quot;&gt;&lt;table&gt;&lt;tr&gt;&lt;td class=&quot;gutter&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;1&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;2&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;3&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;4&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;5&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;6&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;7&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;8&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;9&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;10&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;11&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;12&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;13&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;14&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;15&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;16&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;17&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;18&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;19&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;20&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;21&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;22&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;23&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;24&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;25&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;26&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;27&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;28&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;29&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;30&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;31&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;td class=&quot;code&quot;&gt;&lt;pre&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;import&lt;/span&gt; ray&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;ray.init()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;meta&quot;&gt;@ray.remote&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;title class_&quot;&gt;GlobalVarActor&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;__init__&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.global_var = &lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;set_global_var&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self, var&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.global_var = var&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;get_global_var&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.global_var&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;meta&quot;&gt;@ray.remote&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;title class_&quot;&gt;Actor&lt;/span&gt;:&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;__init__&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self, global_var_actor&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.global_var_actor = global_var_actor&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;    &lt;span class=&quot;keyword&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;title function_&quot;&gt;f&lt;/span&gt;(&lt;span class=&quot;params&quot;&gt;self&lt;/span&gt;):&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;        &lt;span class=&quot;keyword&quot;&gt;return&lt;/span&gt; ray.get(&lt;span class=&quot;variable language_&quot;&gt;self&lt;/span&gt;.global_var_actor.get_global_var.remote()) + &lt;span class=&quot;number&quot;&gt;3&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;global_var_actor = GlobalVarActor.remote()&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;actor = Actor.remote(global_var_actor)&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;ray.get(global_var_actor.set_global_var.remote(&lt;span class=&quot;number&quot;&gt;4&lt;/span&gt;))&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;comment&quot;&gt;# This returns 7 correctly.&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;span class=&quot;line&quot;&gt;&lt;span class=&quot;keyword&quot;&gt;assert&lt;/span&gt; ray.get(actor.f.remote()) == &lt;span class=&quot;number&quot;&gt;7&lt;/span&gt;&lt;/span&gt;&lt;br&gt;&lt;/pre&gt;&lt;/td&gt;&lt;/tr&gt;&lt;/table&gt;&lt;/figure&gt;

&lt;p&gt;Failing example:&lt;/p&gt;
    
    </summary>
    
    
      <category term="强化学习Ray" scheme="https://blog.aivgg.com/categories/%E5%BC%BA%E5%8C%96%E5%AD%A6%E4%B9%A0Ray/"/>
    
    
      <category term="Ray" scheme="https://blog.aivgg.com/tags/Ray/"/>
    
      <category term="强化学习" scheme="https://blog.aivgg.com/tags/%E5%BC%BA%E5%8C%96%E5%AD%A6%E4%B9%A0/"/>
    
  </entry>
  
  <entry>
    <title>Packaging with pyinstaller</title>
    <link href="https://blog.aivgg.com/posts/2243094175.html"/>
    <id>https://blog.aivgg.com/posts/2243094175.html</id>
    <published>2023-02-22T09:34:46.000Z</published>
    <updated>2023-09-14T03:24:25.000Z</updated>
    
    <content type="html"><![CDATA[<p><a href="https://blog.csdn.net/qq_35722703/article/details/121117169">https://blog.csdn.net/qq_35722703/article/details/121117169</a></p>]]></content>
    
    <summary type="html">
    
      &lt;p&gt;&lt;a href=&quot;https://blog.csdn.net/qq_35722703/article/details/121117169&quot;&gt;https://blog.csdn.net/qq_35722703/article/details/121117169&lt;/a&gt;&lt;/p&gt;

    
    </summary>
    
    
      <category term="Python基础" scheme="https://blog.aivgg.com/categories/Python%E5%9F%BA%E7%A1%80/"/>
    
    
      <category term="Python" scheme="https://blog.aivgg.com/tags/Python/"/>
    
      <category term="打包" scheme="https://blog.aivgg.com/tags/%E6%89%93%E5%8C%85/"/>
    
  </entry>
  
</feed>
