The disclosure generally relates to circuits for accelerating softmax and log softmax computations.
The softmax function, applied to a vector X of n+1 real values (indexed 0 . . . n), normalizes the values into a probability distribution of n+1 probabilities proportional to the exponentials of the input values. Some input elements may be negative or greater than one. Each element in the tensor resulting from application of the softmax function is in the interval (0, 1), and the elements sum to 1. The softmax function on element xt (subscripts alternatively denoted “x_t” herein) of an input tensor can be stated as:
softmax(xt) = e^(x_t) / SUM(y=0 . . . n, e^(x_y))
In order to avoid overflow, the element having the greatest value (xmax) in the input tensor is subtracted from the exponents in the calculation.
softmax(xt) = e^(x_t − x_max) / SUM(y=0 . . . n, e^(x_y − x_max))
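The max-subtraction identity above can be checked with a short numeric sketch (the function name is illustrative, not part of the disclosure):

```python
import math

def softmax(x):
    """Numerically stable softmax: subtract the greatest element before
    exponentiating so that no exponent overflows."""
    x_max = max(x)                              # greatest value in the input
    exps = [math.exp(v - x_max) for v in x]     # e**(x_t - x_max)
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([2.0, 1.0, -3.0])
```

Shifting every exponent by x_max scales the numerator and denominator by the same factor, so the resulting distribution is unchanged.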
In applications involving large tensors, traversing all elements of the tensor to find the maximum value can consume a considerable amount of time, and until the maximum of the tensor elements is found, the exponential function calculation will be blocked.
For many neural networks, such as convolutional neural networks (CNNs), recurrent neural networks, etc., softmax is applied only in the final layer. Transformer neural networks are now providing encouraging results in applications previously dominated by CNNs. Notably, each attention layer of a transformer neural network can have softmax and dropout operations in addition to the standard matrix multiplications of fully-connected layers. Transformer networks are large, having hundreds of millions to hundreds of billions of parameters, and the softmax function can have a significant negative impact on their performance.
A disclosed method includes transforming in parallel, elements of each group of a plurality of groups of elements of a tensor X into respective power-of-two elements by a processor circuit (202). The respective power-of-two element from element xt of the tensor is pt, pt=(xt*log2e), and pt has an integer part and a fraction part. The method includes determining respective group-level biases for the groups by a comparison circuit (204), wherein the group-level bias of groupm is dm, and dm is an integer part of a maximum of the power-of-two elements of groupm. The method further includes determining a greatest one of the respective group-level biases by the comparison circuit (206) to be a tensor-level bias, dmax.
A disclosed circuit arrangement includes a processor circuit configured to transform in parallel, elements of each group of a plurality of groups of elements of a tensor X into respective power-of-two elements. The respective power-of-two element from element xt of the tensor is pt, pt=(xt*log2e), and pt has an integer part and a fraction part. A first comparison circuit is configured to determine respective group-level biases for the groups. The group-level bias of groupm is dm, and dm is an integer part of a maximum of the power-of-two elements of groupm. A second comparison circuit is configured to determine a greatest one of the respective group-level biases to be a tensor-level bias, dmax.
Other features will be recognized from consideration of the Detailed Description and Claims, which follow.
Various aspects and features of the methods and circuits will become apparent upon review of the following detailed description and upon reference to the drawings in which:
In the following description, numerous specific details are set forth to describe specific examples presented herein. It should be apparent, however, to one skilled in the art, that one or more other examples and/or variations of these examples may be practiced without all the specific details given below. In other instances, well known features have not been described in detail so as not to obscure the description of the examples herein. For ease of illustration, the same reference numerals may be used in different diagrams to refer to the same elements or additional instances of the same element.
The disclosed approaches provide methods and circuitry that address the aforementioned issues. The methods and circuits are useful in neural network inference and training. According to the disclosed approaches, the exponential functions of softmax are transformed into 2^x form. The transformation is explained as follows.
a^b = e^(b*ln(a))
For a = 2 and b*ln(a) = x, so that b = x*log2e:
e^x = 2^(x*log2e)
A bias, dmax, is computed to prevent overflow and underflow and to align terms for summing:
dmax = [xmax*log2e]
where [⋅] is a floor operation. The softmax function can be restated as:
softmax(xt) = 2^(xt*log2e − dmax) / SUM(y=0 . . . n, 2^(xy*log2e − dmax))
where xt is element t of tensor X, and the summation is over all elements of X.
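The restated power-of-two form can be sketched numerically as follows (a software model with illustrative names): e^x becomes 2^(x*log2e), and the floor-of-max bias dmax is subtracted from every base-2 exponent.

```python
import math

LOG2E = math.log2(math.e)

def softmax_pow2(x):
    """Softmax computed in the 2**(x*log2e - dmax) form described above."""
    p = [v * LOG2E for v in x]       # power-of-two exponents p_t = x_t*log2e
    d_max = math.floor(max(p))       # tensor-level bias: floor of max p_t
    terms = [2.0 ** (pt - d_max) for pt in p]
    total = sum(terms)
    return [t / total for t in terms]
```

Because dmax is an integer, subtracting it only shifts the binary exponent of each term, which is what makes the later exponent-adjustment circuits cheap.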
The term, 2^(xt*log2e), can be decomposed as:
2^(xt*log2e) = 2^(xt_k + xt_j) = 2^(xt_k) * 2^(xt_j)
where xt_k is the integer part, xt_j is the fractional part of (xt*log2e), and xt_j is in the interval [0, 1). To calculate the softmax function according to the disclosed approaches, three components are calculated: xt_k, 2^(xt_j), and dmax. As xt_j is in the interval [0, 1), 2^(xt_j) can be approximated by polynomial fitting with acceptable precision and degree. The value, 2^(xt_j), is a floating-point number and can have an 8-bit exponent, for example. After polynomial fitting, the exponent is modified by xt_k − dmax.
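The integer/fraction split and the polynomial approximation of 2^(xt_j) can be sketched as follows. The degree-3 coefficients here are a simple Taylor-style fit chosen for illustration, not the disclosed hardware's fitted polynomial:

```python
import math

LOG2E = math.log2(math.e)
LN2 = math.log(2.0)

def split_power_of_two(x_t):
    """Split p_t = x_t*log2e into integer part k and fraction j in [0, 1)."""
    p_t = x_t * LOG2E
    k = math.floor(p_t)
    return k, p_t - k

def pow2_frac_poly(j):
    """Degree-3 approximation of 2**j on [0, 1) (illustrative coefficients)."""
    return 1.0 + LN2 * j + (LN2 ** 2 / 2.0) * j ** 2 + (LN2 ** 3 / 6.0) * j ** 3

k, j = split_power_of_two(2.5)
approx = pow2_frac_poly(j) * 2.0 ** k   # reassembled value, close to e**2.5
```

Multiplying the fitted 2^j value by 2^k amounts to adding k to its exponent field, which is the adjustment the text describes.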
The disclosed methods and circuitry significantly reduce the time expended in computing dmax by dividing an input tensor into several groups, converting tensor elements into power-of-two values, determining group-level biases, adjusting the power-of-two values according to the group-level biases, and summing the adjusted values of the groups. The number of tensor elements in each group is set according to the desired level of computational parallelism. In the exemplary methods and circuitry, each group has 8 tensor elements, though different implementations can have more or fewer tensor elements depending on hardware capabilities.
The groups of elements can be input one group at a time, and a processor circuit is configured to multiply the elements of the group by log2e in parallel (xt*log2e for t=0 . . . 7). The purpose of multiplying xt by log2e is to transform e^(xt) to the form 2^y (“power-of-two form”), where y = xt*log2e, per the derivations above. The products produced from each group m are used to determine the group-level bias, dm. The dm of group m is the integer part of the greatest one of the products of the group ([max(xt*log2e for t=0 . . . 7)]). A tensor-level bias, dmax, is determined by finding the greatest of the dm values as the groups are successively processed.
The dm, along with the xt_k and 2^(xt_j) values, are used to adjust the computed products and prevent overflow and underflow relative to the group. The 2^(xt_j) values for a group are determined by polynomial fitting, and the power-of-two values are adjusted by xt_k − dm + (the exponent bits of 2^(xt_j)). As each group is computed, the group-biased power-of-two values, e^(xt)*2^(−dm), are stored in buffer2, shown by dashed block 104, in association with the group-level dm. Each e^(xt)*2^(−dm) is a floating point value having an exponent equal to xt_k − dm + (the exponent bits of 2^(xt_j)), and a mantissa equal to the mantissa of 2^(xt_j). Buffer2 can be an on-chip or off-chip RAM (relative to the computational circuitry), and the group-level dm values and associated group-biased power-of-two values can be input by a streaming or direct memory access (DMA) interface.
The adjusted power-of-two values are accumulated into a group-level sum (2^(−dm)*summ = SUMgroup_m = sum(e^(xt)*2^(−dm)) for all t in groupm) as the adjusted power-of-two values are computed. The group-level sums are accumulated into a tensor-level sum as each group completes. Once dm is determined, the group-level dm is compared to the current running bias, d′max, and the group-level sum 2^(−dm)*summ is aligned with the current sum (2^(−d′max)*sum′) according to the current value of d′max. Once aligned, the group-level sum is added to the current sum to produce a new current sum.
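The running group accumulation described above can be sketched as a software model (the function name is illustrative). Each group's partial sum carries its own bias dm, and the running sum is re-aligned whenever a larger bias appears:

```python
import math

LOG2E = math.log2(math.e)

def accumulate_groups(x, group_size=8):
    """Group-wise biased accumulation: each group sum is scaled by 2**(-d_m);
    the running sum stays aligned to the running maximum bias d'_max."""
    d_run = None      # running bias d'_max
    sum_run = 0.0     # running sum, scaled by 2**(-d_run)
    for g in range(0, len(x), group_size):
        group = x[g:g + group_size]
        p = [v * LOG2E for v in group]
        d_m = math.floor(max(p))                    # group-level bias
        s_m = sum(2.0 ** (pt - d_m) for pt in p)    # group sum * 2**(-d_m)
        if d_run is None:
            d_run, sum_run = d_m, s_m
        else:
            d_new = max(d_run, d_m)
            # align both partial sums to the new bias before adding
            sum_run = (sum_run * 2.0 ** (d_run - d_new)
                       + s_m * 2.0 ** (d_m - d_new))
            d_run = d_new
    return d_run, sum_run   # sum of e**x_t equals sum_run * 2**d_run
```

Only the group maxima are compared, so the full-tensor maximum scan that blocks the exponential pipeline in the conventional formulation is avoided.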
The tensor-level dmax and tensor-level sum, 2^(−dmax)*sum (which is 2^(−dmax)*sum(2^(xt*log2e)) over all elements of the tensor), are available once all groups have been processed.
Once dmax has been determined, the group-biased power-of-two values (e^(xt)*2^(−dm)) are tensor-wise adjusted based on the tensor-level dmax value. The tensor-wise adjustments for elements in groupm are made by retrieving from buffer2 the dm and the exponents of the associated group-wise-adjusted power-of-two values e^(xt)*2^(−dm). The exponents of e^(xt)*2^(−dm) in groupm are added to (dm − dmax) to generate the exponents of the e^(xt)*2^(−dmax) values, which are illustrated in the column 106 of blocks. The mantissas of the e^(xt)*2^(−dmax) values are the same as the mantissas of the corresponding values from buffer2. Though not shown in
A group of p+1 tensor elements (xt, t=0 . . . p) is read from buffer1 102 and input in parallel to processor circuitry 202. Processor circuit 202 computes products (“power-of-two elements”) of xt*log2e for t=0 . . . p in parallel. The p+1 power-of-two elements are provided on parallel signal lines to circuit 204, which compares values of the p+1 power-of-two elements and extracts and provides the integer portion of the greatest one of the values as dm. The compare-and-select circuit 206 compares the dm value from circuit 204 to the current dmax value in register 208 and selects the greater of the two values to update the contents of the register.
The power-of-two elements computed by processor circuits 202 are floating point values, and the integer portions (groupm xt_k) and fraction portions (groupm xt_j) of the values are determined from the mantissas and exponents. The integer portions are provided to the subtraction circuits 210, and the fraction portions are provided to the processor circuitry 212, which can be a vector processor that performs multiply-and-accumulate (“MAC”) operations in parallel.
The subtraction circuits 210 compute in parallel the differences between the integer portions and the group-level bias, dm (xt_k − dm for t=0 . . . p). The processor circuitry 212 computes in parallel 2^(xt_j) for t=0 . . . p by polynomial fitting of the fraction portions, xt_j. The tensor elements of the next group (groupm+1) can be input to the processor circuitry 202 for computing the power-of-two elements while circuit 204 determines the group-level bias dm, the subtraction circuits 210 compute the differences (xt_k − dm for t=0 . . . p), and the processor circuitry 212 computes 2^(xt_j) for t=0 . . . p for groupm.
The differences and the exponents of the 2^(xt_j) values are input to adder circuits 214, which compute in parallel the exponents of the group-biased power-of-two elements. Each e^(xt)*2^(−dm) is a floating point value having an exponent equal to xt_k − dm + (the exponent bits of 2^(xt_j)), and a mantissa equal to the mantissa of 2^(xt_j). The group-biased power-of-two values for the group are stored in buffer2 104 in association with the group-level bias dm.
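A minimal model of what the adder circuits 214 produce: the float e^(xt)*2^(−dm) is assembled by adding (xt_k − dm) to the exponent of the polynomial-fit value 2^(xt_j). Here `math.ldexp` stands in for writing the exponent field directly, and the helper name is illustrative:

```python
import math

def group_biased_value(k, j_pow, d_m):
    """Assemble e**x_t * 2**(-d_m) from integer part k, fraction value
    j_pow = 2**j, and group bias d_m: the exponent shift is k - d_m."""
    return math.ldexp(j_pow, k - d_m)   # j_pow * 2**(k - d_m)
```

`math.ldexp(m, e)` returns m * 2**e, which mirrors an integer addition into the float's exponent field while leaving the mantissa untouched.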
The group-biased power-of-two values for the group are input to summing circuit 216, which sums the group-biased power-of-two values into a group-level sum (SUMm = sum(e^(xt)*2^(−dm)) for all xt in groupm).
The update circuit 218 accumulates the group-level sums as each group-level sum is provided by summing circuit 216. The update circuit 218 inputs the group-level sum from summing circuit 216, the current greatest bias value, dmax from register 208, and the current accumulated SUM from register 220. The update circuit aligns the group-level sum and the current accumulated SUM according to dmax and produces a new SUM that is stored in register 220.
Once all groups of tensor elements of a tensor (e.g., group0 . . . groupi of a tensor having i+1 groups) have been processed and a final tensor-level sum has been computed, control circuit 222 can activate the final softmax circuitry 224. The final softmax circuitry generates final softmax values group-by-group, with the p+1 softmax values generated in parallel. The final softmax circuitry inputs the tensor-level bias, dmax, from register 208, the final tensor-level SUM from register 220 (SUM = 2^(−dmax)*sum(2^(xt*log2e))), and reads the group-biased power-of-two elements of groupm and the associated group-level bias dm from buffer2 104.
The subtractor circuit 226 of the final softmax circuitry determines the difference between the group-level bias, dm, and the tensor-level bias, dmax (dm − dmax). The difference from subtractor circuit 226 and the exponents of the group-biased power-of-two elements (exp(e^(xt)*2^(−dm))) are input to adder circuits 228. The adder circuits 228 compute in parallel sums of the difference and the exponents of the e^(xt)*2^(−dm) terms from buffer2. The sums from exponent adders 228 are exponents that are paired with the corresponding mantissas of the e^(xt)*2^(−dm) terms from buffer2 to provide the tensor-biased terms, “xt_dmax,” as dividends to the divider circuit 230. The exponent of xt_dmax is exp(xt_dmax) = (dm − dmax) + exp(e^(xt)*2^(−dm)), and the mantissa of xt_dmax is man(xt_dmax) = man(e^(xt)*2^(−dm)).
Divider circuitry 230, which can be a vector division circuit, computes in parallel the final softmax values (xt_dmax/SUM for t=0 . . . p).
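The final pass can be modeled in software as follows (illustrative names: `groups` stands in for buffer2's stored pairs of group bias and group-biased values, and `total` for the tensor-level SUM already scaled by 2^(−dmax)):

```python
def final_softmax(groups, d_max, total):
    """Re-bias each stored group from its own d_m to the tensor-level d_max,
    then divide by the tensor-level SUM, as the divider circuit does."""
    out = []
    for d_m, vals in groups:             # vals hold e**x_t * 2**(-d_m)
        scale = 2.0 ** (d_m - d_max)     # exponent adjustment d_m - d_max
        out.extend(v * scale / total for v in vals)
    return out
```

The common 2^(−dmax) factor appears in both the numerators and the SUM, so it cancels in the division and the outputs equal the unbiased softmax values.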
The control circuit 222 controls activation of the final softmax circuit 224 and final log_softmax circuit 232. The final softmax circuit and the final log_softmax circuit can be operated alone or in parallel with one another. For example, in response to a state of mode control signals, the control circuit 222 can activate the final softmax circuit 224 and deactivate final log_softmax circuit 232, deactivate the final softmax circuit 224 and activate final log_softmax circuit 232, or activate both the final softmax circuit 224 and the final log_softmax circuit 232 to operate in parallel. The control circuit 222 can gate clock signals to the final softmax circuit 224 and final log_softmax circuit 232 to reduce power consumption when only one of the circuits is activated.
The formula for log(softmax) follows from the derivation above:
log(softmax(xt)) = (xt − dmax) − log(2^(−dmax)*SUM(y=0 . . . n, 2^(xy*log2e)))
The term, 2^(−dmax)*SUM(2^(xy*log2e)), is the tensor-level SUM computed as described above.
The input variable of the natural log function is a floating point number, which is represented in the form:
(−1)^s * (1 + M) * 2^(E − E0)
where s is the sign bit, M is the mantissa, and E is the exponent, which is shifted by a constant bias E0. The log function on such a value y can be written as:
log(y) = (Ey + log2(1 + My)) / log2e
where Ey is the unbiased exponent, My is in the interval [0, 1), and log2(1 + My) can be calculated by polynomial fitting.
The final log_softmax circuit 232 is activated once the tensor-level dmax is available. A group of p+1 tensor elements (xt, t=0 . . . p) is read from buffer1 102 and input in parallel to the subtraction circuits 234. The subtraction circuits compute in parallel the differences (xt − dmax for t=0 . . . p).
The mantissa (MSUM) of the SUM is input to the processor circuitry 212, which is configured to compute log2(1+MSUM). The exponent of the SUM (ESUM) is input to circuit 236, which converts ESUM to a floating point value. Adder 238 sums the values output from circuits 212 and 236 (float(ESUM)+log2(1+MSUM)), and the sum is input to processor circuitry 240. Processor circuitry 240 is configured to compute:
SUMlog = (float(ESUM) + log2(1 + MSUM)) / log2e
The processor circuitry 240 can be processor circuitry dedicated to computing SUMlog, or circuitry 212.
The SUMlog and the p+1 differences (xt − dmax for t=0 . . . p) from subtraction circuits 234 are input to subtraction circuits 242. Subtraction circuits 242 compute in parallel (xt − dmax) − SUMlog for t=0 . . . p, and the p+1 output terms are log(softmax(xt)).
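The SUMlog path can be sketched numerically as follows. Here `math.frexp` stands in for reading the float's exponent and mantissa fields, and the helper name is illustrative; note that in this floating-point sketch the dmax bias re-enters scaled by ln(2) when it leaves the base-2 domain, so the result is the exact natural-log softmax:

```python
import math

LOG2E = math.log2(math.e)

def log_softmax_from_sum(x, d_max, biased_sum):
    """log(softmax) from the tensor-level SUM represented as (1+M)*2**E:
    SUMlog = (E + log2(1+M)) / log2e, i.e. the natural log of SUM."""
    m, e = math.frexp(biased_sum)        # biased_sum = m * 2**e, m in [0.5, 1)
    one_plus_m, exp2 = 2.0 * m, e - 1    # rewrite as (1+M) * 2**E, 1+M in [1, 2)
    sum_log = (exp2 + math.log2(one_plus_m)) / LOG2E   # = ln(SUM)
    ln_total = d_max * math.log(2.0) + sum_log         # = ln(sum of e**x_t)
    return [v - ln_total for v in x]
```

Only log2(1+M) on [1, 2) needs a polynomial; the exponent contributes through an addition, matching the circuit 236 / adder 238 split described above.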
The example shows the relative timing of operations involved in processing groups 0, 1, m, m+1, and i of i+1 groups of tensor elements (1 < m < i). In time slot t0, the p+1 tensor elements of group0 are input to the circuit arrangement. In time slot t1, the tensor elements of group0 are multiplied by log2e, and in parallel therewith, the tensor elements of group1 are input. In time slot t2, the group-level bias, d0, is determined, along with polynomial fitting of the xt_j terms and differences between the xt_k terms and d0. Also in time slot t2, the tensor elements of group1 are multiplied by log2e. Though not shown, the tensor elements of group2 would be input in time slot t2.
In time slot t3, the differences and the exponents of the 2^(xt_j) values computed from group0 are summed into the group-biased power-of-two elements and stored in association with the group-level bias d0. The group-biased power-of-two values for group0 are summed into a group-level sum (“SUMgroup0”). Also in time slot t3, the group-level bias, d1, is determined, along with polynomial fitting of the xt_j terms and differences between the xt_k terms and d1.
In time slot t4, the group-level bias, d0, is compared to the current tensor-level bias, dmax, and the current tensor-level dmax is updated to the value of d0, since d0 is the first maximum computed. Also during time slot t4, the differences and the exponents of the 2^(xt_j) values computed from group1 are summed into the group-biased power-of-two elements and stored in association with the group-level bias d1. The group-biased power-of-two values for group1 are summed into a group-level sum (“SUMgroup1”).
In time slot t5, the group-level sum is accumulated with the current SUM. The group-level sum SUMgroup0 is aligned with the current accumulated SUM according to the current dmax, and the aligned values are added to produce a new SUM. Also in time slot t5, the group-level bias, d1, is compared to the current tensor-level bias, dmax. If d1 > dmax, then the current tensor-level bias, dmax, is updated to the value of d1. Otherwise, dmax remains unchanged.
In time slot t6, the group-level sum is accumulated with the current SUM. The group-level sum SUMgroup1 is aligned with the current accumulated SUM according to the current dmax, and the aligned values are added to produce a new SUM.
Processing of the final groupi of tensor elements commences in time slot t0+i, and the processing is similar to that described above for time slots t0+i through t0+i+5. In time slot t0+i+6, the final operations of softmax processing begin.
In time slot t0+i+6, the group-biased power-of-two elements of group0 and the associated group-level bias d0 are input, and the e^(xt)*2^(−dmax) values, “xt_dmax,” are computed for group0 as described above. Each e^(xt)*2^(−dmax) is a floating point value having an exponent equal to xt_k − dmax + (the exponent bits of 2^(xt_j)), and a mantissa equal to the mantissa of 2^(xt_j). In time slot t0+i+7, the p+1 softmax values of group0 are computed as (xt_dmax/SUM for t=0 . . . p) and then output. Though not shown, the operations in time slots t0+i+6 and t0+i+7 would be performed for group1 . . . groupi in ensuing time slots. For example, in time slot t0+i+7, the group-biased power-of-two elements of group1 and the associated group-level bias d1 are input, and e^(xt)*2^(−dmax) values are computed for group1. In time slot t0+i+8, the p+1 softmax values of group1 are computed as (xt_dmax/SUM for t=0 . . . p) and then output.
In time slot t0+i+5, the p+1 tensor elements of group0 are input, and the parallel subtraction circuits compute the differences (xt − dmax for t=0 . . . p).
In timeslot t0+i+6, log2(1+MSUM) is computed from the mantissa (MSUM) of the SUM, and the exponent of the SUM (ESUM) is converted to a floating point value.
In timeslot t0+i+7, the log2(1+MSUM) and float(ESUM) values are summed.
In timeslot t0+i+8, the SUMlog term is computed from the log2(1+MSUM) and float(ESUM) values as:
SUMlog = (float(ESUM) + log2(1 + MSUM)) / log2e
In timeslot t0+i+9, the p+1 log(softmax) values of group0 are computed in parallel as (xt−dmax)−SUMlog for t=0 . . . p, and then output.
The operations are the same as those described in
In time slot t0+i+7, the p+1 softmax values of group0 are computed and then output. In parallel with the final softmax operation in time slot t0+i+7, the log2(1+MSUM) and float(ESUM) values are summed for log(softmax). The log(softmax) operations in time slots t0+i+8 and t0+i+9 are as described in
Referring to the PS 602, each of the processing units includes one or more central processing units (CPUs) and associated circuits, such as memories, interrupt controllers, direct memory access (DMA) controllers, memory management units (MMUs), floating point units (FPUs), and the like. The interconnect 616 includes various switches, busses, communication links, and the like configured to interconnect the processing units, as well as interconnect the other components in the PS 602 to the processing units.
The OCM 614 includes one or more RAM modules, which can be distributed throughout the PS 602. For example, the OCM 614 can include battery-backed RAM (BBRAM), tightly coupled memory (TCM), and the like. The memory controller 610 can include a DRAM interface for accessing external DRAM. The peripherals 608, 615 can include one or more components that provide an interface to the PS 602. For example, the peripherals can include a graphics processing unit (GPU), a display interface (e.g., DisplayPort, high-definition multimedia interface (HDMI) port, etc.), universal serial bus (USB) ports, Ethernet ports, universal asynchronous receiver-transmitter (UART) ports, serial peripheral interface (SPI) ports, general-purpose input/output (GPIO) ports, serial advanced technology attachment (SATA) ports, PCIe ports, and the like. The peripherals 615 can be coupled to the MIO 613. The peripherals 608 can be coupled to the transceivers 607. The transceivers 607 can include serializer/deserializer (SERDES) circuits, MGTs, and the like.
Various logic may be implemented as circuitry to carry out one or more of the operations and activities described herein and/or shown in the figures. In these contexts, a circuit or circuitry may be referred to as “logic,” “module,” “engine,” or “block.” It should be understood that logic, modules, engines and blocks are all circuits that carry out one or more of the operations/activities. In certain implementations, a programmable circuit is one or more computer circuits programmed to execute a set (or sets) of instructions stored in a ROM or RAM and/or operate according to configuration data stored in a configuration memory.
Though aspects and features may in some cases be described in individual figures, it will be appreciated that features from one figure can be combined with features of another figure even though the combination is not explicitly shown or explicitly described as a combination.
The methods and circuits are thought to be applicable to a variety of systems that compute softmax and log(softmax) functions. Other aspects and features will be apparent to those skilled in the art from consideration of the specification. The methods and circuits may be implemented as one or more processors configured to execute software, as an application-specific integrated circuit (ASIC), or as logic on a programmable logic device. It is intended that the specification and drawings be considered as examples only, with a true scope of the invention being indicated by the following claims.