---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[62], line 8
5 svi = SVI(model.model, guide, optimizer, loss=TraceEnum_ELBO())
7 for step in range(500):
----> 8 loss = svi.step(Y_tensor)
9 if step % 50 == 0:
10 print(f"Step {step}: loss = {loss:.2f}")
File c:\Users\erik4\AppData\Local\Programs\Python\Python313\Lib\site-packages\pyro\infer\svi.py:145, in SVI.step(self, *args, **kwargs)
143 # get loss and compute gradients
144 with poutine.trace(param_only=True) as param_capture:
--> 145 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
147 params = set(
148 site["value"].unconstrained() for site in param_capture.trace.nodes.values()
149 )
151 # actually perform gradient steps
152 # torch.optim objects gets instantiated for any params that haven't been seen yet
File c:\Users\erik4\AppData\Local\Programs\Python\Python313\Lib\site-packages\pyro\infer\traceenum_elbo.py:451, in TraceEnum_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
443 """
444 :returns: an estimate of the ELBO
445 :rtype: float
(...) 448 Performs backward on the ELBO of each particle.
449 """
450 elbo = 0.0
--> 451 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
452 elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
453 if is_identically_zero(elbo_particle):
File c:\Users\erik4\AppData\Local\Programs\Python\Python313\Lib\site-packages\pyro\infer\traceenum_elbo.py:374, in TraceEnum_ELBO._get_traces(self, model, guide, args, kwargs)
372 raise NotImplementedError("TraceEnum_ELBO does not support GuideMessenger")
373 if self.max_plate_nesting == float("inf"):
--> 374 self._guess_max_plate_nesting(model, guide, args, kwargs)
375 if self.vectorize_particles:
376 guide = self._vectorized_num_particles(guide)
File c:\Users\erik4\AppData\Local\Programs\Python\Python313\Lib\site-packages\pyro\infer\elbo.py:153, in ELBO._guess_max_plate_nesting(self, model, guide, args, kwargs)
151 with poutine.block():
152 guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
--> 153 model_trace = poutine.trace(
154 poutine.replay(model, trace=guide_trace)
155 ).get_trace(*args, **kwargs)
156 guide_trace = prune_subsample_sites(guide_trace)
157 model_trace = prune_subsample_sites(model_trace)
File c:\Users\erik4\AppData\Local\Programs\Python\Python313\Lib\site-packages\pyro\poutine\trace_messenger.py:216, in TraceHandler.get_trace(self, *args, **kwargs)
208 def get_trace(self, *args, **kwargs) -> Trace:
209 """
210 :returns: data structure
211 :rtype: pyro.poutine.Trace
(...) 214 Calls this poutine and returns its trace instead of the function's return value.
215 """
--> 216 self(*args, **kwargs)
217 return self.msngr.get_trace()
File c:\Users\erik4\AppData\Local\Programs\Python\Python313\Lib\site-packages\pyro\poutine\trace_messenger.py:191, in TraceHandler.__call__(self, *args, **kwargs)
187 self.msngr.trace.add_node(
188 "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
189 )
190 try:
--> 191 ret = self.fn(*args, **kwargs)
192 except (ValueError, RuntimeError) as e:
193 exc_type, exc_value, traceback = sys.exc_info()
File c:\Users\erik4\AppData\Local\Programs\Python\Python313\Lib\site-packages\pyro\poutine\messenger.py:32, in _context_wrap(context, fn, *args, **kwargs)
25 def _context_wrap(
26 context: "Messenger",
27 fn: Callable,
28 *args: Any,
29 **kwargs: Any,
30 ) -> Any:
31 with context:
---> 32 return fn(*args, **kwargs)
File c:\Users\erik4\AppData\Local\Programs\Python\Python313\Lib\site-packages\pyro\poutine\messenger.py:32, in _context_wrap(context, fn, *args, **kwargs)
25 def _context_wrap(
26 context: "Messenger",
27 fn: Callable,
28 *args: Any,
29 **kwargs: Any,
30 ) -> Any:
31 with context:
---> 32 return fn(*args, **kwargs)
Cell In[61], line 22, in SimpleDiscreteHMM.model(self, Y)
19 num_states = self.num_states
21 with pyro.plate("donors", n_donors, dim=-2):
---> 22 z_prev = pyro.sample("z_0", dist.Categorical(pyro.param("start_probs")))
24 for t in range(T):
25 z_t = pyro.sample(f"z_{t}", dist.Categorical(pyro.param("trans_probs")[z_prev]))
File c:\Users\erik4\AppData\Local\Programs\Python\Python313\Lib\site-packages\pyro\primitives.py:88, in param(name, init_tensor, constraint, event_dim)
86 # Note effectful(-) requires the double passing of name below.
87 args = (name,) if init_tensor is None else (name, init_tensor)
---> 88 value = _param(*args, constraint=constraint, event_dim=event_dim, name=name)
89 assert value is not None # type narrowing guaranteed by _param
90 return value
File c:\Users\erik4\AppData\Local\Programs\Python\Python313\Lib\site-packages\pyro\poutine\runtime.py:461, in effectful.<locals>._fn(name, infer, obs, *args, **kwargs)
444 msg = Message(
445 type=type,
446 name=name,
(...) 458 infer=infer if infer is not None else {},
459 )
460 # apply the stack and return its return value
--> 461 apply_stack(msg)
462 if TYPE_CHECKING:
463 assert msg["value"] is not None
File c:\Users\erik4\AppData\Local\Programs\Python\Python313\Lib\site-packages\pyro\poutine\runtime.py:383, in apply_stack(initial_msg)
380 if msg["stop"]:
381 break
--> 383 default_process_message(msg)
385 for frame in stack[-pointer:]:
386 frame._postprocess_message(msg)
File c:\Users\erik4\AppData\Local\Programs\Python\Python313\Lib\site-packages\pyro\poutine\runtime.py:345, in default_process_message(msg)
342 msg["done"] = True
343 return
--> 345 msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
347 # after fn has been called, update msg to prevent it from being called again.
348 msg["done"] = True
File c:\Users\erik4\AppData\Local\Programs\Python\Python313\Lib\site-packages\pyro\params\param_store.py:249, in ParamStoreDict.get_param(self, name, init_tensor, constraint, event_dim)
233 """
234 Get parameter from its name. If it does not yet exist in the
235 ParamStore, it will be created and stored.
(...) 246 :rtype: torch.Tensor
247 """
248 if init_tensor is None:
--> 249 return self[name]
250 else:
251 return self.setdefault(name, init_tensor, constraint)
File c:\Users\erik4\AppData\Local\Programs\Python\Python313\Lib\site-packages\pyro\params\param_store.py:129, in ParamStoreDict.__getitem__(self, name)
125 def __getitem__(self, name: str) -> torch.Tensor:
126 """
127 Get the *constrained* value of a named parameter.
128 """
--> 129 unconstrained_value = self._params[name]
131 # compute the constrained value
132 constraint = self._constraints[name]
KeyError: 'start_probs'