Skip to content

Commit

Permalink
universalEmbedding extra初始化类添加
Browse files Browse the repository at this point in the history
  • Loading branch information
lengfeng343 committed Nov 24, 2022
1 parent 6b72590 commit ada9858
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.tencent.angel.graph.client.psf.universalembedding;

import com.tencent.angel.graph.data.EmbeddingOrGrad;
import com.tencent.angel.graph.data.UniversalEmbeddingNode;
import com.tencent.angel.graph.utils.GraphMatrixUtils;
import com.tencent.angel.ml.matrix.psf.update.base.PartitionUpdateParam;
import com.tencent.angel.ml.matrix.psf.update.base.UpdateFunc;
import com.tencent.angel.ps.storage.vector.ServerLongAnyRow;
import com.tencent.angel.ps.storage.vector.element.IElement;
import com.tencent.angel.psagent.matrix.transport.router.operator.ILongKeyAnyValuePartOp;

/**
* A PS function to initialize the universal embedding matrix on PS with extra embeddings
*/
public class UniversalEmbeddingExtraInitAsNodes extends UpdateFunc {
/**
* Create a new UpdateParam
*/
public UniversalEmbeddingExtraInitAsNodes(UniversalEmbeddingExtraInitParam param) {
super(param);
}

public UniversalEmbeddingExtraInitAsNodes() {
this(null);
}

@Override
public void partitionUpdate(PartitionUpdateParam partParam) {
PartUniversalEmbeddingExtraInitParam param = (PartUniversalEmbeddingExtraInitParam) partParam;
ServerLongAnyRow row = GraphMatrixUtils.getPSLongKeyRow(psContext, param);
ILongKeyAnyValuePartOp keyValuePart = (ILongKeyAnyValuePartOp) param.getKeyValuePart();
int dim = param.getDim();
int numSlots = param.getNumSlots();
long[] nodeIds = keyValuePart.getKeys();
IElement[] embeddings = keyValuePart.getValues();
row.startWrite();
try {
for (int i = 0; i < nodeIds.length; i++) {
UniversalEmbeddingNode embeddingNode = (UniversalEmbeddingNode) row.get(nodeIds[i]);
if (embeddingNode == null) {
if (numSlots <= 1) {
embeddingNode = new UniversalEmbeddingNode(((EmbeddingOrGrad)embeddings[i])
.getValues());
} else {
float[] slotsValues = new float[dim * (numSlots - 1)];
embeddingNode = new UniversalEmbeddingNode(((EmbeddingOrGrad)embeddings[i])
.getValues(), slotsValues, numSlots);
}
row.set(nodeIds[i], embeddingNode);
} else {
embeddingNode.setEmbeddings(((EmbeddingOrGrad)embeddings[i])
.getValues());
}
}
} finally {
row.endWrite();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.tencent.angel.graph.client.psf.universalembedding;

import com.tencent.angel.PartitionKey;
import com.tencent.angel.ml.matrix.MatrixMeta;
import com.tencent.angel.ml.matrix.psf.update.base.PartitionUpdateParam;
import com.tencent.angel.ml.matrix.psf.update.base.UpdateParam;
import com.tencent.angel.ps.storage.vector.element.IElement;
import com.tencent.angel.psagent.PSAgentContext;
import com.tencent.angel.psagent.matrix.transport.router.KeyValuePart;
import com.tencent.angel.psagent.matrix.transport.router.RouterUtils;

import java.util.ArrayList;
import java.util.List;

public class UniversalEmbeddingExtraInitParam extends UpdateParam {

private long[] nodeIds;

private IElement[] embeddings;

private int dim;

private int numSlots;

public UniversalEmbeddingExtraInitParam(int matrixId, long[] nodeIds, IElement[] embeddings,
int dim, int numSlots) {
super(matrixId);
this.nodeIds = nodeIds;
this.embeddings = embeddings;
this.dim = dim;
this.numSlots = numSlots;
}

@Override
public List<PartitionUpdateParam> split() {
MatrixMeta meta = PSAgentContext.get().getMatrixMetaManager().getMatrixMeta(matrixId);
PartitionKey[] parts = meta.getPartitionKeys();

KeyValuePart[] splits = RouterUtils.split(meta, 0, nodeIds, embeddings);
assert parts.length == splits.length;

List<PartitionUpdateParam> partParams = new ArrayList<>(parts.length);
for(int i = 0; i < parts.length; i++) {
if(splits[i] != null && splits[i].size() > 0) {
partParams.add(new PartUniversalEmbeddingExtraInitParam(matrixId, parts[i], splits[i],
dim, numSlots));
}
}

return partParams;
}
}

0 comments on commit ada9858

Please sign in to comment.