Interface for configuring model training based on data as tf.Tensors.

interface ModelFitArgs {
    batchSize?: number;
    callbacks?: BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[];
    classWeight?: ClassWeight | ClassWeight[] | ClassWeightMap;
    epochs?: number;
    initialEpoch?: number;
    sampleWeight?: Tensor<Rank>;
    shuffle?: boolean;
    stepsPerEpoch?: number;
    validationData?: [Tensor<Rank> | Tensor<Rank>[], Tensor<Rank> | Tensor<Rank>[]] | [Tensor<Rank> | Tensor<Rank>[], Tensor<Rank> | Tensor<Rank>[], Tensor<Rank> | Tensor<Rank>[]];
    validationSplit?: number;
    validationSteps?: number;
    verbose?: 2 | ModelLoggingVerbosity;
    yieldEvery?: YieldEveryOptions;
}

Properties

batchSize?: number

Number of samples per gradient update. If unspecified, it will default to 32.
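
Each gradient update consumes one batch, so batchSize determines how many updates happen per pass over the data. The sketch below uses a hypothetical helper, `batchRanges`, which is not part of the TensorFlow.js API. It shows how N samples divide into batches under the documented default of 32, assuming the final batch simply holds the remainder.

```typescript
// Hypothetical helper (not a tfjs export): splits `numSamples` sample indices
// into [start, end) ranges of at most `batchSize` samples. Uses the documented
// default of 32 when batchSize is omitted.
function batchRanges(numSamples: number, batchSize: number = 32): Array<[number, number]> {
  if (!Number.isInteger(batchSize) || batchSize <= 0) {
    throw new Error(`batchSize must be a positive integer, got ${batchSize}`);
  }
  const ranges: Array<[number, number]> = [];
  for (let start = 0; start < numSamples; start += batchSize) {
    // Assumption: the last batch is smaller when numSamples is not a multiple of batchSize.
    ranges.push([start, Math.min(start + batchSize, numSamples)]);
  }
  return ranges;
}

// 100 samples with the default batch size: three full batches plus one of 4.
const ranges = batchRanges(100); // [[0, 32], [32, 64], [64, 96], [96, 100]]
console.log(ranges);
```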

callbacks?: BaseCallback[] | CustomCallbackArgs | CustomCallbackArgs[]

List of callbacks to be called during training. Can have one or more of the following callbacks:

  • onTrainBegin(logs): called when training starts.
  • onTrainEnd(logs): called when training ends.
  • onEpochBegin(epoch, logs): called at the start of every epoch.
  • onEpochEnd(epoch, logs): called at the end of every epoch.
  • onBatchBegin(batch, logs): called at the start of every batch.
  • onBatchEnd(batch, logs): called at the end of every batch.
  • onYield(epoch, batch, logs): called every yieldEvery milliseconds with the current epoch, batch and logs. The logs are the same as in onBatchEnd(). Note that onYield can skip batches or epochs. See also the docs for yieldEvery below.
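
The nesting of these hooks is easiest to see in a simulated loop. The sketch below is an illustration only: `TrainingHooks`, `Logs` and `simulateFit` are hypothetical names modeled on the callback list above (the real tfjs types are richer), no model is trained, and onYield is omitted because its timing depends on yieldEvery.

```typescript
// Hypothetical shapes modeled on the callback names listed above.
type Logs = { [key: string]: number };

interface TrainingHooks {
  onTrainBegin?: (logs?: Logs) => void | Promise<void>;
  onTrainEnd?: (logs?: Logs) => void | Promise<void>;
  onEpochBegin?: (epoch: number, logs?: Logs) => void | Promise<void>;
  onEpochEnd?: (epoch: number, logs?: Logs) => void | Promise<void>;
  onBatchBegin?: (batch: number, logs?: Logs) => void | Promise<void>;
  onBatchEnd?: (batch: number, logs?: Logs) => void | Promise<void>;
}

// Simulated loop that only fires the hooks, in the order described above.
// Hooks may be async, so each call is awaited before training "continues".
async function simulateFit(epochs: number, batchesPerEpoch: number, hooks: TrainingHooks): Promise<void> {
  await hooks.onTrainBegin?.({});
  for (let epoch = 0; epoch < epochs; epoch++) {
    await hooks.onEpochBegin?.(epoch, {});
    for (let batch = 0; batch < batchesPerEpoch; batch++) {
      await hooks.onBatchBegin?.(batch, {});
      // Fake, steadily decreasing loss value so the logs have something in them.
      await hooks.onBatchEnd?.(batch, { loss: 1 / (epoch * batchesPerEpoch + batch + 1) });
    }
    await hooks.onEpochEnd?.(epoch, {});
  }
  await hooks.onTrainEnd?.({});
}

const events: string[] = [];
simulateFit(2, 2, {
  onEpochBegin: (epoch) => { events.push(`epochBegin ${epoch}`); },
  onBatchEnd: (batch, logs) => { events.push(`batchEnd ${batch} loss=${logs?.loss}`); },
  onEpochEnd: (epoch) => { events.push(`epochEnd ${epoch}`); },
}).then(() => console.log(events.join('\n')));
```

Any subset of the hooks may be supplied; missing ones are skipped via optional chaining.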

classWeight?: ClassWeight | ClassWeight[] | ClassWeightMap

Optional object mapping class indices (integers) to a weight (float) to apply to the model's loss for the samples from this class during training. This can be useful to tell the model to "pay more attention" to samples from an under-represented class.

If the model has multiple outputs, a class weight can be specified for each of the outputs by setting this field to an array of weight objects or to an object that maps model output names (e.g., model.outputNames[0]) to weight objects.
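
To make the effect of a class-weight object concrete, the sketch below shows the idea for a single output on plain arrays. It is not the tfjs implementation: `classWeightedMeanLoss` and `ClassWeightMapSketch` are hypothetical, and it assumes that each sample's loss is multiplied by its class's weight and the products are averaged over the batch, and that classes absent from the map get weight 1.

```typescript
// Hypothetical stand-in for a single-output class-weight object.
type ClassWeightMapSketch = { [classIndex: number]: number };

// Multiplies each per-sample loss by the weight of that sample's class,
// then averages over all samples.
function classWeightedMeanLoss(
  perSampleLoss: number[],
  labels: number[],
  classWeight: ClassWeightMapSketch,
): number {
  if (perSampleLoss.length !== labels.length) {
    throw new Error('perSampleLoss and labels must have the same length');
  }
  let total = 0;
  for (let i = 0; i < labels.length; i++) {
    const w = classWeight[labels[i]] ?? 1; // assumption: unlisted classes weigh 1
    total += w * perSampleLoss[i];
  }
  return total / labels.length;
}

// Class 1 is under-represented (1 of 4 samples), so it is weighted 3x.
console.log(classWeightedMeanLoss([1, 1, 1, 1], [0, 0, 0, 1], { 0: 1, 1: 3 })); // 1.5
```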

epochs?: number

Integer number of times to iterate over the training data arrays.

initialEpoch?: number

Epoch at which to start training (useful for resuming a previous training run). When this is used, epochs is the index of the "final epoch": the model is not trained for epochs additional iterations, but only until the epoch with index epochs is reached.
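
The interaction between initialEpoch and epochs is easy to misread, so the sketch below lists which epoch indices would run. `epochIndices` is a hypothetical helper written from the rule stated above, not a tfjs function.

```typescript
// Hypothetical helper: epoch indices that run, treating `epochs` as the index
// at which training stops (exclusive) rather than as a count of new epochs.
function epochIndices(epochs: number, initialEpoch: number = 0): number[] {
  const indices: number[] = [];
  for (let epoch = initialEpoch; epoch < epochs; epoch++) {
    indices.push(epoch);
  }
  return indices;
}

console.log(epochIndices(5));    // [0, 1, 2, 3, 4]: a fresh run of 5 epochs
console.log(epochIndices(5, 3)); // [3, 4]: resuming at epoch 3 runs only 2 more
```

To train for 5 more epochs after resuming at epoch 3, epochs would therefore be set to 8, not 5.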

sampleWeight?: Tensor<Rank>

Optional array of the same length as x, containing weights to apply to the model's loss for each sample. In the case of temporal data, you can pass a 2D array with shape (samples, sequenceLength), to apply a different weight to every timestep of every sample. In this case you should make sure to specify sampleWeightMode="temporal" in compile().
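
The temporal case can be pictured as a weight per timestep multiplying a loss per timestep. The sketch below shows only that concept on plain arrays; `temporalWeightedMeanLoss` is hypothetical, not tfjs code, and averaging over every timestep (including zero-weighted ones) is an assumption about normalization.

```typescript
// Hypothetical helper: applies a (samples x sequenceLength) weight matrix to
// per-timestep losses of the same shape and returns the mean of the products.
function temporalWeightedMeanLoss(losses: number[][], weights: number[][]): number {
  if (losses.length !== weights.length) throw new Error('sample count mismatch');
  let total = 0;
  let count = 0;
  losses.forEach((row, i) => {
    if (row.length !== weights[i].length) {
      throw new Error(`sequence length mismatch at sample ${i}`);
    }
    row.forEach((loss, t) => {
      total += loss * weights[i][t];
      count++;
    });
  });
  return count === 0 ? 0 : total / count;
}

// Two samples of length 2; a weight of 0 masks out a timestep (e.g. padding).
console.log(temporalWeightedMeanLoss([[1, 2], [3, 4]], [[1, 0], [0, 1]])); // 1.25
```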

shuffle?: boolean

Whether to shuffle the training data before each epoch. Has no effect when stepsPerEpoch is not null.
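
Shuffling before each epoch means drawing a new ordering of the sample indices every epoch. The sketch below uses the Fisher-Yates algorithm to do that; `shuffledIndices` is a hypothetical helper, and tfjs's own shuffling routine may be implemented differently.

```typescript
// Hypothetical helper: Fisher-Yates shuffle of the indices 0..numSamples-1.
// `random` is injectable so the ordering can be made deterministic in tests.
function shuffledIndices(numSamples: number, random: () => number = Math.random): number[] {
  const indices = Array.from({ length: numSamples }, (_, i) => i);
  for (let i = indices.length - 1; i > 0; i--) {
    const j = Math.floor(random() * (i + 1)); // uniform in 0..i
    [indices[i], indices[j]] = [indices[j], indices[i]];
  }
  return indices;
}

// A fresh permutation is drawn at the start of every epoch.
for (let epoch = 0; epoch < 2; epoch++) {
  console.log(`epoch ${epoch}:`, shuffledIndices(8).join(','));
}
```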

stepsPerEpoch?: number

Total number of steps (batches of samples) before declaring one epoch finished and starting the next epoch. When training with input tensors such as TensorFlow data tensors, the default, null, is equal to the number of unique samples in your dataset divided by the batch size, or 1 if that cannot be determined.
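
The default described above can be written as a one-line rule. In the sketch below, `defaultStepsPerEpoch` is a hypothetical helper, and rounding up (so that a final partial batch still counts as a step) is an assumption, since the text does not say how a non-integer quotient is handled.

```typescript
// Hypothetical helper: samples / batchSize (rounded up, an assumption),
// or 1 when the number of samples cannot be determined.
function defaultStepsPerEpoch(numSamples: number | null, batchSize: number): number {
  if (numSamples == null) return 1;
  return Math.ceil(numSamples / batchSize);
}

console.log(defaultStepsPerEpoch(1000, 32)); // 32 (31.25 rounded up)
console.log(defaultStepsPerEpoch(null, 32)); // 1 (sample count unknown)
```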

validationData?: [Tensor<Rank> | Tensor<Rank>[], Tensor<Rank> | Tensor<Rank>[]] | [Tensor<Rank> | Tensor<Rank>[], Tensor<Rank> | Tensor<Rank>[], Tensor<Rank> | Tensor<Rank>[]]

Data on which to evaluate the loss and any model metrics at the end of each epoch. The model will not be trained on this data. This can be a tuple [xVal, yVal] or a tuple [xVal, yVal, valSampleWeights]. validationData will override validationSplit.

validationSplit?: number

Float between 0 and 1: fraction of the training data to be used as validation data. The model will set apart this fraction of the training data, will not train on it, and will evaluate the loss and any model metrics on this data at the end of each epoch. The validation data is selected from the last samples in the x and y data provided, before shuffling.
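
The split can be reproduced on a plain array to see which samples end up held out. `splitForValidation` below is a hypothetical helper, not a tfjs function; flooring the split point is an assumption about rounding, and the input is taken in its original order because the split happens before shuffling.

```typescript
// Hypothetical helper: the last `validationSplit` fraction of the samples
// becomes validation data; everything before the split point is training data.
function splitForValidation<T>(samples: T[], validationSplit: number): { train: T[]; val: T[] } {
  if (!(validationSplit >= 0 && validationSplit < 1)) {
    throw new Error(`validationSplit must be in [0, 1), got ${validationSplit}`);
  }
  const splitAt = Math.floor(samples.length * (1 - validationSplit)); // assumed rounding
  return { train: samples.slice(0, splitAt), val: samples.slice(splitAt) };
}

const allSamples = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
const { train: trainPart, val: valPart } = splitForValidation(allSamples, 0.2);
console.log(trainPart); // [0, 1, 2, 3, 4, 5, 6, 7]
console.log(valPart);   // [8, 9]
```

Because the tail is always held out, data sorted by class or by time yields an unrepresentative validation set; in that case supplying validationData explicitly may be preferable.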

validationSteps?: number

Only relevant if stepsPerEpoch is specified. Total number of steps (batches of samples) to validate before stopping.

verbose?: 2 | ModelLoggingVerbosity

Verbosity level.

Expected to be 0, 1, or 2. Default: 1.

  • 0: No printed message during the fit() call.
  • 1: In Node.js (tfjs-node), prints the progress bar, together with real-time updates of loss and metric values and training speed. In the browser: no action. This is the default.
  • 2: Not implemented yet.

yieldEvery?: YieldEveryOptions

Configures the frequency of yielding the main thread to other tasks.

In the browser environment, yielding the main thread can improve the responsiveness of the page during training. In the Node.js environment, it can ensure tasks queued in the event loop can be handled in a timely manner.

The value can be one of the following:

  • 'auto': yielding happens at a fixed interval (currently 125 ms). This is the default.
  • 'batch': yield every batch.
  • 'epoch': yield every epoch.
  • a number n: yield every n milliseconds.
  • 'never': never yield. (yielding can still happen through await nextFrame() calls in custom callbacks.)
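
The options above amount to a small decision rule evaluated when a batch or an epoch finishes. The sketch below models that rule; `YieldEverySketch`, `TrainingEvent` and `shouldYield` are hypothetical names, it is not the tfjs implementation, and it only decides *whether* to yield, not how the yield itself is performed.

```typescript
// Hypothetical model of the yieldEvery options listed above.
type YieldEverySketch = 'auto' | 'batch' | 'epoch' | 'never' | number;
type TrainingEvent = 'batch' | 'epoch';

const AUTO_YIELD_INTERVAL_MS = 125; // the 'auto' interval quoted above

// Called after a batch or epoch finishes; returns true if the loop should yield.
function shouldYield(option: YieldEverySketch, event: TrainingEvent, msSinceLastYield: number): boolean {
  if (typeof option === 'number') return msSinceLastYield >= option; // time-based
  if (option === 'auto') return msSinceLastYield >= AUTO_YIELD_INTERVAL_MS;
  if (option === 'batch') return event === 'batch';
  if (option === 'epoch') return event === 'epoch';
  return false; // 'never'
}

console.log(shouldYield('auto', 'batch', 40));  // false: under 125 ms
console.log(shouldYield('auto', 'batch', 130)); // true
console.log(shouldYield(50, 'epoch', 60));      // true: custom 50 ms interval elapsed
```

A time-based option is why onYield can skip batches or epochs: several batches may finish between two yields.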
