Skip to content

Commit

Permalink
Remove base walks & Support multiple schemas
Browse files (browse the repository at this point in the history)
  • Loading branch information
cenyk1230 committed Dec 16, 2019
1 parent 52b4ecd commit 19d2194
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
17 changes: 6 additions & 11 deletions src/main.py
Original file line number | Diff line number | Diff line change
@@ -36,8 +36,8 @@ def parse_args():
parser.add_argument('--eval-type', type=str, default='all',
help='The edge type(s) for evaluation.')

parser.add_argument('--schema', default=None,
help='The metapath schema (e.g., U-I-U).')
parser.add_argument('--schema', type=str, default=None,
help='The metapath schema (e.g., U-I-U,I-U-I).')

parser.add_argument('--dimensions', type=int, default=200,
help='Number of dimensions. Default is 200.')
@@ -159,16 +159,11 @@ def get_batches(pairs, batch_size):
return result

def generate_walks(network_data):
base_network = network_data['Base']

if args.schema is not None:
node_type = load_node_type(file_name + '/node_type.txt')
else:
node_type = None

base_walker = RWGraph(get_G_from_edges(base_network), node_type=node_type)
base_walks = base_walker.simulate_walks(args.num_walks, args.walk_length, schema=args.schema)

all_walks = []
for layer_id in network_data:
if layer_id == 'Base':
@@ -178,19 +173,19 @@ def generate_walks(network_data):
# start to do the random walk on a layer

layer_walker = RWGraph(get_G_from_edges(tmp_data))
layer_walks = layer_walker.simulate_walks(args.num_walks, args.walk_length)
layer_walks = layer_walker.simulate_walks(args.num_walks, args.walk_length, schema=args.schema)

all_walks.append(layer_walks)

print('finish generating the walks')

return base_walks, all_walks
return all_walks


def train_model(network_data, feature_dic, log_name):
base_walks, all_walks = generate_walks(network_data)
all_walks = generate_walks(network_data)

vocab, index2word = generate_vocab([base_walks])
vocab, index2word = generate_vocab(all_walks)

train_pairs = generate_pairs(all_walks, vocab)

Expand Down
10 changes: 8 additions & 2 deletions src/walk.py
Original file line number | Diff line number | Diff line change
@@ -37,10 +37,16 @@ def simulate_walks(self, num_walks, walk_length, schema=None):
walks = []
nodes = list(G.nodes())
# print('Walk iteration:')
if schema is not None:
schema_list = schema.split(',')
for walk_iter in range(num_walks):
random.shuffle(nodes)
for node in nodes:
if schema == None or schema.split('-')[0] == self.node_type[node]:
walks.append(self.walk(walk_length=walk_length, start=node, schema=schema))
if schema is None:
walks.append(self.walk(walk_length=walk_length, start=node))
else:
for schema_iter in schema_list:
if schema_iter.split('-')[0] == self.node_type[node]:
walks.append(self.walk(walk_length=walk_length, start=node, schema=schema_iter))

return walks

0 comments on commit 19d2194

Please sign in to comment.