-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6b72590
commit ada9858
Showing
2 changed files
with
111 additions
and
0 deletions.
There are no files selected for viewing
59 changes: 59 additions & 0 deletions
59
...tencent/angel/graph/client/psf/universalembedding/UniversalEmbeddingExtraInitAsNodes.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
package com.tencent.angel.graph.client.psf.universalembedding; | ||
|
||
import com.tencent.angel.graph.data.EmbeddingOrGrad; | ||
import com.tencent.angel.graph.data.UniversalEmbeddingNode; | ||
import com.tencent.angel.graph.utils.GraphMatrixUtils; | ||
import com.tencent.angel.ml.matrix.psf.update.base.PartitionUpdateParam; | ||
import com.tencent.angel.ml.matrix.psf.update.base.UpdateFunc; | ||
import com.tencent.angel.ps.storage.vector.ServerLongAnyRow; | ||
import com.tencent.angel.ps.storage.vector.element.IElement; | ||
import com.tencent.angel.psagent.matrix.transport.router.operator.ILongKeyAnyValuePartOp; | ||
|
||
/** | ||
* A PS function to initialize the universal embedding matrix on PS with extra embeddings | ||
*/ | ||
public class UniversalEmbeddingExtraInitAsNodes extends UpdateFunc { | ||
/** | ||
* Create a new UpdateParam | ||
*/ | ||
public UniversalEmbeddingExtraInitAsNodes(UniversalEmbeddingExtraInitParam param) { | ||
super(param); | ||
} | ||
|
||
public UniversalEmbeddingExtraInitAsNodes() { | ||
this(null); | ||
} | ||
|
||
@Override | ||
public void partitionUpdate(PartitionUpdateParam partParam) { | ||
PartUniversalEmbeddingExtraInitParam param = (PartUniversalEmbeddingExtraInitParam) partParam; | ||
ServerLongAnyRow row = GraphMatrixUtils.getPSLongKeyRow(psContext, param); | ||
ILongKeyAnyValuePartOp keyValuePart = (ILongKeyAnyValuePartOp) param.getKeyValuePart(); | ||
int dim = param.getDim(); | ||
int numSlots = param.getNumSlots(); | ||
long[] nodeIds = keyValuePart.getKeys(); | ||
IElement[] embeddings = keyValuePart.getValues(); | ||
row.startWrite(); | ||
try { | ||
for (int i = 0; i < nodeIds.length; i++) { | ||
UniversalEmbeddingNode embeddingNode = (UniversalEmbeddingNode) row.get(nodeIds[i]); | ||
if (embeddingNode == null) { | ||
if (numSlots <= 1) { | ||
embeddingNode = new UniversalEmbeddingNode(((EmbeddingOrGrad)embeddings[i]) | ||
.getValues()); | ||
} else { | ||
float[] slotsValues = new float[dim * (numSlots - 1)]; | ||
embeddingNode = new UniversalEmbeddingNode(((EmbeddingOrGrad)embeddings[i]) | ||
.getValues(), slotsValues, numSlots); | ||
} | ||
row.set(nodeIds[i], embeddingNode); | ||
} else { | ||
embeddingNode.setEmbeddings(((EmbeddingOrGrad)embeddings[i]) | ||
.getValues()); | ||
} | ||
} | ||
} finally { | ||
row.endWrite(); | ||
} | ||
} | ||
} |
52 changes: 52 additions & 0 deletions
52
...m/tencent/angel/graph/client/psf/universalembedding/UniversalEmbeddingExtraInitParam.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
package com.tencent.angel.graph.client.psf.universalembedding; | ||
|
||
import com.tencent.angel.PartitionKey; | ||
import com.tencent.angel.ml.matrix.MatrixMeta; | ||
import com.tencent.angel.ml.matrix.psf.update.base.PartitionUpdateParam; | ||
import com.tencent.angel.ml.matrix.psf.update.base.UpdateParam; | ||
import com.tencent.angel.ps.storage.vector.element.IElement; | ||
import com.tencent.angel.psagent.PSAgentContext; | ||
import com.tencent.angel.psagent.matrix.transport.router.KeyValuePart; | ||
import com.tencent.angel.psagent.matrix.transport.router.RouterUtils; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
public class UniversalEmbeddingExtraInitParam extends UpdateParam { | ||
|
||
private long[] nodeIds; | ||
|
||
private IElement[] embeddings; | ||
|
||
private int dim; | ||
|
||
private int numSlots; | ||
|
||
public UniversalEmbeddingExtraInitParam(int matrixId, long[] nodeIds, IElement[] embeddings, | ||
int dim, int numSlots) { | ||
super(matrixId); | ||
this.nodeIds = nodeIds; | ||
this.embeddings = embeddings; | ||
this.dim = dim; | ||
this.numSlots = numSlots; | ||
} | ||
|
||
@Override | ||
public List<PartitionUpdateParam> split() { | ||
MatrixMeta meta = PSAgentContext.get().getMatrixMetaManager().getMatrixMeta(matrixId); | ||
PartitionKey[] parts = meta.getPartitionKeys(); | ||
|
||
KeyValuePart[] splits = RouterUtils.split(meta, 0, nodeIds, embeddings); | ||
assert parts.length == splits.length; | ||
|
||
List<PartitionUpdateParam> partParams = new ArrayList<>(parts.length); | ||
for(int i = 0; i < parts.length; i++) { | ||
if(splits[i] != null && splits[i].size() > 0) { | ||
partParams.add(new PartUniversalEmbeddingExtraInitParam(matrixId, parts[i], splits[i], | ||
dim, numSlots)); | ||
} | ||
} | ||
|
||
return partParams; | ||
} | ||
} |