sparse_caption.pruning package

Submodules

sparse_caption.pruning.masked_layer module

Created on 23 Sep 2020 17:36:39

@author: jiahuei

class sparse_caption.pruning.masked_layer.MaskMixin

Bases: object

static assert_in_kwargs(key, kwargs)
get_masked_weight(weight_name: str)
mask_init_value: float
mask_trainable: bool
mask_type: str
reset_masks() → None
setup_masks(parameters: Union[str, List[str], Tuple[str, ...]], mask_type: str, mask_init_value: float = 1.0, bypass_sigmoid_grad: bool = False) → None
training: bool
class sparse_caption.pruning.masked_layer.MaskedEmbedding(num_embeddings: int, embedding_dim: int, mask_type: str, mask_init_value: float, bypass_sigmoid_grad: bool = False, **kwargs)

Bases: sparse_caption.pruning.masked_layer.MaskMixin, torch.nn.modules.sparse.Embedding

A simple lookup table that stores embeddings of a fixed dictionary and size.

This module is often used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding word embeddings.
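The lookup that a masked embedding performs can be pictured in plain Python. This is a minimal sketch, assuming the mask is a binary matrix with the same shape as the embedding table and is applied element-wise before the lookup. `masked_lookup` is a hypothetical helper, not part of sparse_caption, and nested lists stand in for tensors.

```python
from typing import List

Matrix = List[List[float]]


def masked_lookup(weight: Matrix, mask: Matrix, indices: List[int]) -> Matrix:
    """Return the rows of (weight * mask) selected by `indices`.

    Hypothetical helper: assumes `mask` has the same shape as `weight`
    and is applied element-wise before the lookup.
    """
    # Element-wise product of the embedding table and its mask.
    masked = [[w * m for w, m in zip(w_row, m_row)] for w_row, m_row in zip(weight, mask)]
    # Gather one (masked) row per index, as an embedding lookup does.
    return [list(masked[i]) for i in indices]


weight = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # 3 tokens, embedding_dim = 2
mask = [[1, 0], [1, 1], [0, 1]]                 # pruned entries are 0
print(masked_lookup(weight, mask, [2, 0]))      # [[0.0, 0.6], [0.1, 0.0]]
```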

forward(input: torch.Tensor) → torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling this method directly, since the former takes care of running the registered hooks while the latter silently ignores them.

classmethod from_pretrained(*args, **kwargs)

Creates an Embedding instance from the given 2-dimensional FloatTensor.

mask_init_value: float
mask_trainable: bool
mask_type: str
training: bool
class sparse_caption.pruning.masked_layer.MaskedLSTMCell(input_size: int, hidden_size: int, mask_type: str, mask_init_value: float, bypass_sigmoid_grad: bool = False, **kwargs)

Bases: sparse_caption.pruning.masked_layer.MaskMixin, torch.nn.modules.rnn.LSTMCell

A masked long short-term memory (LSTM) cell.

The weights have the following shapes, with num_chunks = 4 (one chunk per LSTM gate):

    self.weight_ih = Parameter(torch.Tensor(num_chunks * hidden_size, input_size))
    self.weight_hh = Parameter(torch.Tensor(num_chunks * hidden_size, hidden_size))
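A single step of a masked LSTM cell can be sketched without PyTorch. This assumes, without confirmation from this page, that the masks multiply weight_ih and weight_hh element-wise; the gate order (input, forget, cell candidate, output) follows the torch.nn.LSTMCell documentation. `masked_lstm_step` and its helpers are hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def matvec(a: Matrix, x: Vector) -> Vector:
    # Matrix-vector product for row-major nested lists.
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in a]


def apply_mask(w: Matrix, m: Matrix) -> Matrix:
    # Hypothetical element-wise masking of a weight matrix.
    return [[wi * mi for wi, mi in zip(w_row, m_row)] for w_row, m_row in zip(w, m)]


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def masked_lstm_step(x: Vector, h: Vector, c: Vector,
                     w_ih: Matrix, w_hh: Matrix, m_ih: Matrix, m_hh: Matrix,
                     b_ih: Vector, b_hh: Vector) -> Tuple[Vector, Vector]:
    """One LSTM step with masked weights; w_ih is (4*hidden, input), w_hh is (4*hidden, hidden)."""
    hidden = len(h)
    # Pre-activations of all four gates, length 4 * hidden.
    pre = [a + b + p + q for a, b, p, q in zip(matvec(apply_mask(w_ih, m_ih), x),
                                                matvec(apply_mask(w_hh, m_hh), h),
                                                b_ih, b_hh)]
    # Gate order used by torch.nn.LSTMCell: input, forget, cell candidate, output.
    i = [sigmoid(z) for z in pre[:hidden]]
    f = [sigmoid(z) for z in pre[hidden:2 * hidden]]
    g = [math.tanh(z) for z in pre[2 * hidden:3 * hidden]]
    o = [sigmoid(z) for z in pre[3 * hidden:]]
    c_new = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c, i, g)]
    h_new = [ok * math.tanh(ck) for ok, ck in zip(o, c_new)]
    return h_new, c_new


# input_size = 2, hidden_size = 1. Every weight is pruned and biases are zero, so the
# gates equal sigmoid(0) = 0.5, the candidate is tanh(0) = 0 and the cell state is halved.
h1, c1 = masked_lstm_step([1.0, 1.0], [0.0], [1.0],
                          [[1.0, 2.0]] * 4, [[3.0]] * 4, [[0, 0]] * 4, [[0]] * 4,
                          [0.0] * 4, [0.0] * 4)
print(c1)  # [0.5]
```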

forward(input: torch.Tensor, hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) → Tuple[torch.Tensor, torch.Tensor]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling this method directly, since the former takes care of running the registered hooks while the latter silently ignores them.

mask_init_value: float
mask_trainable: bool
mask_type: str
training: bool
class sparse_caption.pruning.masked_layer.MaskedLSTMCellCheckpoint(input_size: int, hidden_size: int, mask_type: str, mask_init_value: float, bypass_sigmoid_grad: bool = False, **kwargs)

Bases: sparse_caption.pruning.masked_layer.MaskMixin, torch.nn.modules.rnn.LSTMCell

A masked long short-term memory (LSTM) cell.

The weights have the following shapes, with num_chunks = 4 (one chunk per LSTM gate):

    self.weight_ih = Parameter(torch.Tensor(num_chunks * hidden_size, input_size))
    self.weight_hh = Parameter(torch.Tensor(num_chunks * hidden_size, hidden_size))

forward(input: torch.Tensor, hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) → Tuple[torch.Tensor, torch.Tensor]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling this method directly, since the former takes care of running the registered hooks while the latter silently ignores them.

mask_init_value: float
mask_trainable: bool
mask_type: str
training: bool
class sparse_caption.pruning.masked_layer.MaskedLinear(in_features: int, out_features: int, mask_type: str, mask_init_value: float, bypass_sigmoid_grad: bool = False, **kwargs)

Bases: sparse_caption.pruning.masked_layer.MaskMixin, torch.nn.modules.linear.Linear

Applies a linear transformation to the incoming data: \(y = xA^T + b\)
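With a mask \(M\) of the same shape as \(A\), a masked linear layer plausibly computes \(y = x(A \odot M)^T + b\). The sketch below assumes exactly that element-wise masking; `masked_linear` is a hypothetical helper operating on Python lists for a single input vector.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def masked_linear(x: Vector, weight: Matrix, mask: Matrix, bias: Vector) -> Vector:
    """Compute y = x (A * M)^T + b for a single input vector x.

    `weight` has shape (out_features, in_features), like torch.nn.Linear.weight;
    `mask` is a hypothetical element-wise mask of the same shape.
    """
    y = []
    for w_row, m_row, b in zip(weight, mask, bias):
        # Each output is a dot product with one masked row of A, plus its bias.
        y.append(sum(xi * wi * mi for xi, wi, mi in zip(x, w_row, m_row)) + b)
    return y


A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # out_features = 2, in_features = 3
M = [[1, 0, 1], [0, 1, 0]]
print(masked_linear([1.0, 1.0, 1.0], A, M, [0.5, -0.5]))  # [4.5, 4.5]
```

With an all-ones mask the function reduces to an ordinary dense linear layer.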

forward(input: torch.Tensor) → torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling this method directly, since the former takes care of running the registered hooks while the latter silently ignores them.

mask_init_value: float
mask_trainable: bool
mask_type: str
training: bool

sparse_caption.pruning.prune module

Created on 25 Sep 2020 19:25:43

@author: jiahuei

class sparse_caption.pruning.prune.PruningMixin(*, mask_type, mask_freeze_scope='', **kwargs)

Bases: object

Mixin class to be used together with torch.nn.Module.

property active_mask_avg
property active_mask_sparsities
active_pruned_weights(named=True)
active_pruning_masks(named=True)
static add_argparse_args(parser: Union[argparse._ArgumentGroup, argparse.ArgumentParser])
property all_mask_avg
property all_mask_sparsities
all_pruned_weights(named=True)
all_pruning_masks(named=True)
property all_weight_sparsities
all_weights(named=True)
static calculate_sparsities(tensor_list, count_nnz_fn)
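A sparsity value is commonly the fraction of zero entries, 1 - nnz / numel. The sketch assumes that reading and that count_nnz_fn returns the number of non-zero entries of one tensor; `layer_sparsities` is a hypothetical stand-in that works on flat lists rather than tensors.

```python
from typing import Callable, List, Sequence


def count_nonzero(values: Sequence[float]) -> int:
    # Number of entries that survived pruning.
    return sum(1 for v in values if v != 0)


def layer_sparsities(tensor_list: List[Sequence[float]],
                     count_nnz_fn: Callable[[Sequence[float]], int] = count_nonzero) -> List[float]:
    # Sparsity of each tensor: the fraction of entries counted as zero.
    return [1.0 - count_nnz_fn(t) / len(t) for t in tensor_list]


print(layer_sparsities([[0.0, 1.5, 0.0, -2.0], [3.0, 0.0]]))  # [0.5, 0.5]
```

Passing a custom count function, for example one that treats tiny magnitudes as zero, changes what counts as "pruned" without touching the sparsity formula.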
static compute_mask(criterion, sparsity_target)
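One common criterion for this kind of pruning is weight magnitude: the sparsity_target fraction of entries with the smallest score is zeroed. The sketch assumes lower scores mean less important weights and that the mask is binary; the criterion and tie-breaking actually used by compute_mask are not documented here, and `magnitude_mask` is hypothetical.

```python
from typing import List


def magnitude_mask(criterion: List[float], sparsity_target: float) -> List[int]:
    """Zero out the `sparsity_target` fraction of entries with the lowest score."""
    n_prune = int(round(sparsity_target * len(criterion)))
    # Indices sorted by ascending score; the first n_prune are pruned.
    order = sorted(range(len(criterion)), key=lambda i: criterion[i])
    pruned = set(order[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(criterion))]


weights = [0.9, -0.05, 0.3, -0.7, 0.01, 0.2]
scores = [abs(w) for w in weights]  # magnitude criterion
print(magnitude_mask(scores, 0.5))  # [1, 0, 1, 1, 0, 0]
```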
compute_sparsity_loss(sparsity_target: float, weight: float, current_step: int, max_step: int)

Loss for controlling the sparsity of Supermasks.

Parameters
  • sparsity_target – Desired sparsity rate.

  • weight

  • current_step

  • max_step

Returns

Scalar loss value.

load_sparse_state_dict(sparse_state_dict: Dict[str, torch.Tensor], strict: bool = True)
load_state_dict: Callable
named_parameters: Callable
prune_weights()
sparsity_check(warning_threshold: float = 0.999)
state_dict: Callable
state_dict_dense(destination=None, prefix='', keep_vars=False, discard_pruning_mask=False, prune_weights=True, binarize_supermasks=False)
state_dict_sparse(destination=None, prefix='', keep_vars=False, discard_pruning_mask=True, prune_weights=True, binarize_supermasks=False)
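Storing only the surviving weights is what makes a sparse state dict smaller than a dense one. The actual format written by state_dict_sparse and read by load_sparse_state_dict is not described on this page; the sketch below uses a hypothetical (indices, values, numel) triple per flattened tensor purely for illustration.

```python
from typing import Dict, List, Tuple

SparseEntry = Tuple[List[int], List[float], int]  # (indices, values, numel)


def to_sparse(dense: List[float]) -> SparseEntry:
    # Keep only the non-zero entries and their positions.
    idx = [i for i, v in enumerate(dense) if v != 0]
    return idx, [dense[i] for i in idx], len(dense)


def to_dense(entry: SparseEntry) -> List[float]:
    # Scatter the stored values back into a zero-filled tensor.
    idx, vals, numel = entry
    dense = [0.0] * numel
    for i, v in zip(idx, vals):
        dense[i] = v
    return dense


state = {"fc.weight": [0.0, 1.5, 0.0, 0.0, -2.0, 0.0]}
sparse_state: Dict[str, SparseEntry] = {k: to_sparse(v) for k, v in state.items()}
print(sparse_state["fc.weight"])  # ([1, 4], [1.5, -2.0], 6)
assert to_dense(sparse_state["fc.weight"]) == state["fc.weight"]
```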
property total_mask_params
property total_weight_params
trainable_pruning_masks(named=True)
update_masks_gradual(sparsity_target: float, current_step: int, start_step: int, prune_steps: int, initial_sparsity: float = 0.0, prune_frequency: int = 1000)

Computes the current sparsity level for gradual pruning and updates the pruning masks when a pruning step is due. See https://arxiv.org/abs/1710.01878 and https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/contrib/model_pruning

Parameters
  • sparsity_target – Final sparsity

  • current_step – Current global step

  • start_step – When to start pruning

  • prune_steps – Number of pruning steps to take

  • initial_sparsity – Starting sparsity

  • prune_frequency – Number of training steps per pruning step

Returns

True if pruning masks are updated at this step, False otherwise.
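The cited paper (Zhu & Gupta, 2017) uses the cubic schedule \(s_t = s_f + (s_i - s_f)\left(1 - \frac{t - t_0}{n\,\Delta t}\right)^3\). Mapping \(s_i\) to initial_sparsity, \(s_f\) to sparsity_target, \(t_0\) to start_step, \(n\) to prune_steps and \(\Delta t\) to prune_frequency is an assumption based on the parameter descriptions above. `gradual_sparsity` is a hypothetical helper that only computes the schedule and does not touch any masks.

```python
from typing import Tuple


def gradual_sparsity(sparsity_target: float, current_step: int, start_step: int,
                     prune_steps: int, initial_sparsity: float = 0.0,
                     prune_frequency: int = 1000) -> Tuple[bool, float]:
    """Cubic sparsity schedule of arXiv:1710.01878.

    Returns (is_pruning_step, sparsity). Assumes pruning ends after
    prune_steps * prune_frequency training steps.
    """
    if prune_steps <= 0 or prune_frequency <= 0:
        raise ValueError("prune_steps and prune_frequency must be positive")
    end_step = start_step + prune_steps * prune_frequency
    # Fraction of the pruning window elapsed, clamped to [0, 1].
    progress = min(max((current_step - start_step) / (end_step - start_step), 0.0), 1.0)
    sparsity = sparsity_target + (initial_sparsity - sparsity_target) * (1.0 - progress) ** 3
    # Masks are only recomputed every `prune_frequency` steps inside the window.
    is_pruning_step = (start_step <= current_step <= end_step
                       and (current_step - start_step) % prune_frequency == 0)
    return is_pruning_step, sparsity


for step in (0, 500, 550, 1000):
    print(step, gradual_sparsity(0.8, step, 0, 10, prune_frequency=100))
```

The cubic shape prunes aggressively early, when many redundant weights remain, and slows down as the target is approached.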

update_masks_once(sparsity_target: float)
Parameters
  • sparsity_target – Desired sparsity rate.

Returns

True if pruning masks are successfully updated.

sparse_caption.pruning.sampler module

Created on 24 Sep 2020 19:53:25

@author: jiahuei

class sparse_caption.pruning.sampler.BernoulliSample(*args, **kwargs)

Bases: torch.autograd.function.Function

static backward(ctx, grad_output)

Defines a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, probs)

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store arbitrary data that can then be retrieved during the backward pass.
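The inherited docstrings only describe the torch.autograd.Function contract. A common way to backpropagate through a sampling step is the straight-through estimator, which treats the sample as the identity in the backward pass; whether BernoulliSample.backward does exactly this is an assumption. The framework-free sketch mirrors the forward/backward split without ctx, and `StraightThroughBernoulli` is hypothetical.

```python
import random
from typing import List


class StraightThroughBernoulli:
    """Framework-free sketch of a Bernoulli sampler with a straight-through gradient."""

    @staticmethod
    def forward(probs: List[float], rng: random.Random) -> List[float]:
        # Draw 1.0 with probability p and 0.0 otherwise.
        return [1.0 if rng.random() < p else 0.0 for p in probs]

    @staticmethod
    def backward(grad_output: List[float]) -> List[float]:
        # Straight-through: treat sampling as the identity and pass the gradient unchanged.
        return list(grad_output)


rng = random.Random(0)
print(StraightThroughBernoulli.forward([0.0, 1.0], rng))   # [0.0, 1.0]
print(StraightThroughBernoulli.backward([0.1, -0.2]))      # [0.1, -0.2]
```

In PyTorch the same idea would live in a torch.autograd.Function subclass whose forward and backward take ctx as their first argument.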

class sparse_caption.pruning.sampler.BernoulliSampleSigmoid(*args, **kwargs)

Bases: sparse_caption.pruning.sampler.BernoulliSample

static forward(ctx, logits)

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store arbitrary data that can then be retrieved during the backward pass.

class sparse_caption.pruning.sampler.Round(*args, **kwargs)

Bases: torch.autograd.function.Function

static backward(ctx, grad_output)

Defines a formula for differentiating the operation with backward mode automatic differentiation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by as many outputs as forward() returned (None will be passed in for non-tensor outputs of the forward function), and it should return as many tensors as there were inputs to forward(). Each argument is the gradient w.r.t. the given output, and each returned value should be the gradient w.r.t. the corresponding input. If an input is not a Tensor or is a Tensor not requiring grads, you can just pass None as a gradient for that input.

The context can be used to retrieve tensors saved during the forward pass. It also has an attribute ctx.needs_input_grad as a tuple of booleans representing whether each input needs gradient. E.g., backward() will have ctx.needs_input_grad[0] = True if the first input to forward() needs gradient computed w.r.t. the output.

static forward(ctx, probs)

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store arbitrary data that can then be retrieved during the backward pass.

class sparse_caption.pruning.sampler.RoundSigmoid(*args, **kwargs)

Bases: sparse_caption.pruning.sampler.Round

static forward(ctx, logits)

Performs the operation.

This function is to be overridden by all subclasses.

It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).

The context can be used to store arbitrary data that can then be retrieved during the backward pass.

sparse_caption.pruning.sampler.bernoulli_sample_sigmoid(logits, bypass_sigmoid_grad=False)

Performs stochastic Bernoulli sampling. Accepts raw logits instead of normalised probabilities.
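Sampling from logits amounts to computing p = sigmoid(logits) and drawing a Bernoulli sample. The sketch also shows two possible gradients with respect to the logits: the chain rule through the sigmoid, which multiplies by p(1 - p), or the upstream gradient passed straight through. Reading bypass_sigmoid_grad as a switch between these two is an assumption based on its name; `bernoulli_from_logits` is hypothetical.

```python
import math
import random
from typing import List, Tuple


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bernoulli_from_logits(logits: List[float], grad_output: List[float], rng: random.Random,
                          bypass_sigmoid_grad: bool = False) -> Tuple[List[float], List[float]]:
    """Return (sample, gradient w.r.t. logits) under a straight-through assumption."""
    probs = [sigmoid(z) for z in logits]
    sample = [1.0 if rng.random() < p else 0.0 for p in probs]
    if bypass_sigmoid_grad:
        # Assumed meaning: hand the upstream gradient straight to the logits.
        grad_logits = list(grad_output)
    else:
        # Chain rule through the sigmoid: d sigmoid(z) / dz = p * (1 - p).
        grad_logits = [g * p * (1.0 - p) for g, p in zip(grad_output, probs)]
    return sample, grad_logits


print(bernoulli_from_logits([0.0], [1.0], random.Random(0))[1])                            # [0.25]
print(bernoulli_from_logits([0.0], [1.0], random.Random(0), bypass_sigmoid_grad=True)[1])  # [1.0]
```

Skipping the sigmoid derivative keeps gradients from vanishing when the logits saturate, at the cost of a more biased estimate.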

sparse_caption.pruning.sampler.rounding_sigmoid(logits, bypass_sigmoid_grad=False)

Performs deterministic binarisation with an adjustable threshold. Accepts raw logits instead of normalised probabilities.
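Because the sigmoid is monotonic, thresholding sigmoid(z) at t is the same as thresholding the raw logit z at log(t / (1 - t)), which is why binarisation can work on logits directly. rounding_sigmoid as documented exposes no threshold argument, so the `threshold` parameter and the `round_sigmoid` helper below are hypothetical.

```python
import math
from typing import List


def round_sigmoid(logits: List[float], threshold: float = 0.5) -> List[float]:
    """Binarise sigmoid(logits) deterministically: 1.0 where sigmoid(z) >= threshold."""
    if not 0.0 < threshold < 1.0:
        raise ValueError("threshold must lie strictly between 0 and 1")
    # sigmoid(z) >= t is equivalent to z >= log(t / (1 - t)) for 0 < t < 1,
    # so the comparison can be made on the raw logits.
    logit_threshold = math.log(threshold / (1.0 - threshold))
    return [1.0 if z >= logit_threshold else 0.0 for z in logits]


print(round_sigmoid([-2.0, 0.0, 0.5, 3.0]))        # [0.0, 1.0, 1.0, 1.0]
print(round_sigmoid([-2.0, 0.0, 0.5, 3.0], 0.75))  # [0.0, 0.0, 0.0, 1.0]
```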

Module contents

Created on 14 Jun 2019 15:37:45

@author: jiahuei